refactor!: use SQLite and SQLAlchemy for database backend (#332)

* use sqlite + sqlalchemy as a database backend

* change entries getter

* page filterstate.page_size persistent

* add test for entry.id filter

* fix closing library

* fix tag search, adding field

* add field position

* add fields reordering

* use folder

* take field position into consideration

* fix adding tag

* fix test

* try to catch the correct exception, moron

* dont expunge subtags

* DRY models

* rename LibraryField, add is_default property

* remove field.position unique constraint
This commit is contained in:
yed
2024-09-09 12:06:01 +07:00
committed by GitHub
parent 67f7e4dcf9
commit e5e7b8afc6
85 changed files with 4803 additions and 6752 deletions

View File

@@ -1,51 +0,0 @@
name: PySide App Test
on: [ push, pull_request ]
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip'
- name: Install system dependencies
run: |
# dont run update, it is slow
# sudo apt-get update
sudo apt-get install -y --no-install-recommends \
libxkbcommon-x11-0 \
x11-utils \
libyaml-dev \
libegl1-mesa \
libxcb-icccm4 \
libxcb-image0 \
libxcb-keysyms1 \
libxcb-randr0 \
libxcb-render-util0 \
libxcb-xinerama0 \
libopengl0 \
libxcb-cursor0 \
libpulse0
- name: Install dependencies
run: |
pip install -Ur requirements.txt
- name: Run TagStudio app and check exit code
run: |
xvfb-run --server-args="-screen 0, 1920x1200x24 -ac +extension GLX +render -noreset" python tagstudio/tag_studio.py --ci -o /tmp/
exit_code=$?
if [ $exit_code -eq 0 ]; then
echo "TagStudio ran successfully"
else
echo "TagStudio failed with exit code $exit_code"
exit 1
fi

View File

@@ -24,7 +24,7 @@ jobs:
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install mypy==1.10.0
pip install mypy==1.11.2
mkdir tagstudio/.mypy_cache
- uses: tsuyoshicho/action-mypy@v4

View File

@@ -1,6 +1,6 @@
name: pytest
on: [push, pull_request]
on: [ push, pull_request ]
jobs:
pytest:
@@ -11,12 +11,64 @@ jobs:
- name: Checkout repo
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip'
- name: Install system dependencies
run: |
# dont run update, it is slow
# sudo apt-get update
sudo apt-get install -y --no-install-recommends \
libxkbcommon-x11-0 \
x11-utils \
libyaml-dev \
libegl1-mesa \
libxcb-icccm4 \
libxcb-image0 \
libxcb-keysyms1 \
libxcb-randr0 \
libxcb-render-util0 \
libxcb-xinerama0 \
libopengl0 \
libxcb-cursor0 \
libpulse0
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
python -m pip install --upgrade uv
uv pip install --system -r requirements.txt
uv pip install --system -r requirements-dev.txt
- name: Run tests
- name: Run pytest
run: |
pytest tagstudio/tests/
xvfb-run pytest --cov-report xml --cov=tagstudio
- name: Store coverage
uses: actions/upload-artifact@v4
with:
name: 'coverage'
path: 'coverage.xml'
coverage:
name: Check Code Coverage
runs-on: ubuntu-latest
needs: pytest
steps:
- name: Load coverage
uses: actions/download-artifact@v4
with:
name: 'coverage'
- name: Check Code Coverage
uses: yedpodtrzitko/coverage@main
with:
thresholdAll: 0.5
thresholdNew: 0.5
thresholdModified: 0.5
coverageFile: coverage.xml
token: ${{ secrets.GITHUB_TOKEN }}
sourceDir: tagstudio/src

View File

@@ -1,11 +1,20 @@
name: Ruff
on: [ push, pull_request ]
jobs:
ruff:
ruff-format:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
with:
version: 0.4.2
version: 0.6.4
args: 'format --check'
ruff-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
with:
version: 0.6.4
args: 'check'

View File

@@ -1,6 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.2
rev: v0.6.4
hooks:
- id: ruff-format
- id: ruff

View File

@@ -1,9 +1,33 @@
[tool.ruff]
exclude = ["main_window.py", "home_ui.py", "resources.py", "resources_rc.py"]
[tool.ruff.lint]
select = ["E", "F", "UP", "B", 'SIM']
ignore = ["E402", "E501", "F541"]
[tool.mypy]
strict_optional = false
disable_error_code = ["union-attr", "annotation-unchecked", "import-untyped"]
disable_error_code = ["func-returns-value", "import-untyped"]
explicit_package_bases = true
warn_unused_ignores = true
exclude = ['tests']
check_untyped_defs = true
[[tool.mypy.overrides]]
module = "tests.*"
ignore_errors = true
[[tool.mypy.overrides]]
module = "src.qt.main_window"
ignore_errors = true
[[tool.mypy.overrides]]
module = "src.qt.ui.home_ui"
ignore_errors = true
[[tool.mypy.overrides]]
module = "src.core.ts_core"
ignore_errors = true
[tool.pytest.ini_options]
#addopts = "-m 'not qt'"
qt_api = "pyside6"

View File

@@ -1,6 +1,8 @@
ruff==0.4.2
ruff==0.6.4
pre-commit==3.7.0
pytest==8.2.0
Pyinstaller==6.6.0
mypy==1.10.0
syrupy==4.6.1
mypy==1.11.2
syrupy==4.7.1
pytest-qt==4.4.0
pytest-cov==5.0.0

View File

@@ -10,3 +10,5 @@ numpy==1.26.4
rawpy==0.21.0
pillow-heif==0.16.0
chardet==5.2.0
structlog==24.4.0
SQLAlchemy==2.0.34

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,5 @@
from enum import Enum
VERSION: str = "9.3.2" # Major.Minor.Patch
VERSION_BRANCH: str = "" # Usually "" or "Pre-Release"
@@ -120,49 +122,13 @@ ALL_FILE_TYPES: list[str] = (
+ SHORTCUT_TYPES
)
BOX_FIELDS = ["tag_box", "text_box"]
TEXT_FIELDS = ["text_line", "text_box"]
DATE_FIELDS = ["datetime"]
TAG_COLORS = [
"",
"black",
"dark gray",
"gray",
"light gray",
"white",
"light pink",
"pink",
"red",
"red orange",
"orange",
"yellow orange",
"yellow",
"lime",
"light green",
"mint",
"green",
"teal",
"cyan",
"light blue",
"blue",
"blue violet",
"violet",
"purple",
"lavender",
"berry",
"magenta",
"salmon",
"auburn",
"dark brown",
"brown",
"light brown",
"blonde",
"peach",
"warm gray",
"cool gray",
"olive",
]
TAG_FAVORITE = 1
TAG_ARCHIVED = 0
class LibraryPrefs(Enum):
IS_EXCLUDE_LIST = True
EXTENSION_LIST: list[str] = [".json", ".xmp", ".aae"]
PAGE_SIZE: int = 500
DB_VERSION: int = 1

View File

@@ -19,21 +19,15 @@ class Theme(str, enum.Enum):
COLOR_DISABLED_BG = "#65440D12"
class SearchMode(int, enum.Enum):
"""Operational modes for item searching."""
AND = 0
OR = 1
class OpenStatus(enum.IntEnum):
NOT_FOUND = 0
SUCCESS = 1
CORRUPTED = 2
class FieldID(int, enum.Enum):
TITLE = 0
AUTHOR = 1
ARTIST = 2
DESCRIPTION = 4
NOTES = 5
TAGS = 6
CONTENT_TAGS = 7
META_TAGS = 8
DATE_PUBLISHED = 14
SOURCE = 21
class MacroID(enum.Enum):
AUTOFILL = "autofill"
SIDECAR = "sidecar"
BUILD_URL = "build_url"
MATCH = "match"
CLEAN_URL = "clean_url"

View File

@@ -1,42 +0,0 @@
from typing import TypedDict
from typing_extensions import NotRequired
class JsonLibary(TypedDict("", {"ts-version": str})):
# "ts-version": str
tags: "list[JsonTag]"
collations: "list[JsonCollation]"
fields: list # TODO
macros: "list[JsonMacro]"
entries: "list[JsonEntry]"
ext_list: list[str]
is_exclude_list: bool
ignored_extensions: NotRequired[list[str]] # deprecated
class JsonBase(TypedDict):
id: int
class JsonTag(JsonBase, total=False):
name: str
aliases: list[str]
color: str
shorthand: str
subtag_ids: list[int]
class JsonCollation(JsonBase, total=False):
title: str
e_ids_and_pages: list[list[int]]
sort_order: str
cover_id: int
class JsonEntry(JsonBase, total=False):
filename: str
path: str
fields: list[dict] # TODO
class JsonMacro(JsonBase, total=False): ... # TODO

View File

@@ -0,0 +1 @@
from .alchemy import * # noqa

View File

@@ -0,0 +1,6 @@
from .models import Entry
from .library import Library
from .models import Tag
from .enums import ItemType
__all__ = ["Entry", "Library", "Tag", "ItemType"]

View File

@@ -0,0 +1,48 @@
from pathlib import Path
import structlog
from sqlalchemy import Dialect, Engine, String, TypeDecorator, create_engine, text
from sqlalchemy.orm import DeclarativeBase
logger = structlog.getLogger(__name__)
class PathType(TypeDecorator):
impl = String
cache_ok = True
def process_bind_param(self, value: Path, dialect: Dialect):
if value is not None:
return Path(value).as_posix()
return None
def process_result_value(self, value: str, dialect: Dialect):
if value is not None:
return Path(value)
return None
class Base(DeclarativeBase):
type_annotation_map = {Path: PathType}
def make_engine(connection_string: str) -> Engine:
return create_engine(connection_string)
def make_tables(engine: Engine) -> None:
logger.info("creating db tables")
Base.metadata.create_all(engine)
# tag IDs < 1000 are reserved
# create tag and delete it to bump the autoincrement sequence
# TODO - find a better way
with engine.connect() as conn:
conn.execute(text("INSERT INTO tags (id, name, color) VALUES (999, 'temp', 1)"))
conn.execute(text("DELETE FROM tags WHERE id = 999"))
conn.commit()
def drop_tables(engine: Engine) -> None:
logger.info("dropping db tables")
Base.metadata.drop_all(engine)

View File

@@ -0,0 +1,140 @@
import enum
from dataclasses import dataclass
from pathlib import Path
class TagColor(enum.IntEnum):
DEFAULT = 1
BLACK = 2
DARK_GRAY = 3
GRAY = 4
LIGHT_GRAY = 5
WHITE = 6
LIGHT_PINK = 7
PINK = 8
RED = 9
RED_ORANGE = 10
ORANGE = 11
YELLOW_ORANGE = 12
YELLOW = 13
LIME = 14
LIGHT_GREEN = 15
MINT = 16
GREEN = 17
TEAL = 18
CYAN = 19
LIGHT_BLUE = 20
BLUE = 21
BLUE_VIOLET = 22
VIOLET = 23
PURPLE = 24
LAVENDER = 25
BERRY = 26
MAGENTA = 27
SALMON = 28
AUBURN = 29
DARK_BROWN = 30
BROWN = 31
LIGHT_BROWN = 32
BLONDE = 33
PEACH = 34
WARM_GRAY = 35
COOL_GRAY = 36
OLIVE = 37
class SearchMode(enum.IntEnum):
"""Operational modes for item searching."""
AND = 0
OR = 1
class ItemType(enum.Enum):
ENTRY = 0
COLLATION = 1
TAG_GROUP = 2
@dataclass
class FilterState:
"""Represent a state of the Library grid view."""
# these should remain
page_index: int | None = None
page_size: int | None = None
search_mode: SearchMode = SearchMode.AND # TODO - actually implement this
# these should be erased on update
# tag name
tag: str | None = None
# tag ID
tag_id: int | None = None
# entry id
id: int | None = None
# whole path
path: Path | str | None = None
# file name
name: str | None = None
# a generic query to be parsed
query: str | None = None
def __post_init__(self):
# strip values automatically
if query := (self.query and self.query.strip()):
# parse the value
if ":" in query:
kind, _, value = query.partition(":")
else:
# default to tag search
kind, value = "tag", query
if kind == "tag_id":
self.tag_id = int(value)
elif kind == "tag":
self.tag = value
elif kind == "path":
self.path = value
elif kind == "name":
self.name = value
elif kind == "id":
self.id = int(self.id) if str(self.id).isnumeric() else self.id
else:
self.tag = self.tag and self.tag.strip()
self.tag_id = (
int(self.tag_id) if str(self.tag_id).isnumeric() else self.tag_id
)
self.path = self.path and str(self.path).strip()
self.name = self.name and self.name.strip()
self.id = int(self.id) if str(self.id).isnumeric() else self.id
if self.page_index is None:
self.page_index = 0
if self.page_size is None:
self.page_size = 500
@property
def summary(self):
"""Show query summary"""
return (
self.query or self.tag or self.name or self.tag_id or self.path or self.id
)
@property
def limit(self):
return self.page_size
@property
def offset(self):
return self.page_size * self.page_index
class FieldTypeEnum(enum.Enum):
TEXT_LINE = "Text Line"
TEXT_BOX = "Text Box"
TAGS = "Tags"
DATETIME = "Datetime"
BOOLEAN = "Checkbox"

View File

@@ -0,0 +1,176 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TYPE_CHECKING
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr
from .db import Base
from .enums import FieldTypeEnum
if TYPE_CHECKING:
from .models import Entry, Tag, ValueType
class BaseField(Base):
__abstract__ = True
@declared_attr
def id(cls) -> Mapped[int]:
return mapped_column(primary_key=True, autoincrement=True)
@declared_attr
def type_key(cls) -> Mapped[str]:
return mapped_column(ForeignKey("value_type.key"))
@declared_attr
def type(cls) -> Mapped[ValueType]:
return relationship(foreign_keys=[cls.type_key], lazy=False) # type: ignore
@declared_attr
def entry_id(cls) -> Mapped[int]:
return mapped_column(ForeignKey("entries.id"))
@declared_attr
def entry(cls) -> Mapped[Entry]:
return relationship(foreign_keys=[cls.entry_id]) # type: ignore
@declared_attr
def position(cls) -> Mapped[int]:
return mapped_column(default=0)
def __hash__(self):
return hash(self.__key())
def __key(self):
raise NotImplementedError
value: Any
class BooleanField(BaseField):
__tablename__ = "boolean_fields"
value: Mapped[bool]
def __key(self):
return (self.type, self.value)
def __eq__(self, value) -> bool:
if isinstance(value, BooleanField):
return self.__key() == value.__key()
raise NotImplementedError
class TextField(BaseField):
__tablename__ = "text_fields"
value: Mapped[str | None]
def __key(self) -> tuple:
return self.type, self.value
def __eq__(self, value) -> bool:
if isinstance(value, TextField):
return self.__key() == value.__key()
elif isinstance(value, (TagBoxField, DatetimeField)):
return False
raise NotImplementedError
class TagBoxField(BaseField):
__tablename__ = "tag_box_fields"
tags: Mapped[set[Tag]] = relationship(secondary="tag_fields")
def __key(self):
return (
self.entry_id,
self.type_key,
)
@property
def value(self) -> None:
"""For interface compatibility with other field types."""
return None
def __eq__(self, value) -> bool:
if isinstance(value, TagBoxField):
return self.__key() == value.__key()
raise NotImplementedError
class DatetimeField(BaseField):
__tablename__ = "datetime_fields"
value: Mapped[str | None]
def __key(self):
return (self.type, self.value)
def __eq__(self, value) -> bool:
if isinstance(value, DatetimeField):
return self.__key() == value.__key()
raise NotImplementedError
@dataclass
class DefaultField:
id: int
name: str
type: FieldTypeEnum
is_default: bool = field(default=False)
class _FieldID(Enum):
"""Only for bootstrapping content of DB table"""
TITLE = DefaultField(
id=0, name="Title", type=FieldTypeEnum.TEXT_LINE, is_default=True
)
AUTHOR = DefaultField(id=1, name="Author", type=FieldTypeEnum.TEXT_LINE)
ARTIST = DefaultField(id=2, name="Artist", type=FieldTypeEnum.TEXT_LINE)
URL = DefaultField(id=3, name="URL", type=FieldTypeEnum.TEXT_LINE)
DESCRIPTION = DefaultField(id=4, name="Description", type=FieldTypeEnum.TEXT_LINE)
NOTES = DefaultField(id=5, name="Notes", type=FieldTypeEnum.TEXT_BOX)
TAGS = DefaultField(id=6, name="Tags", type=FieldTypeEnum.TAGS)
TAGS_CONTENT = DefaultField(
id=7, name="Content Tags", type=FieldTypeEnum.TAGS, is_default=True
)
TAGS_META = DefaultField(
id=8, name="Meta Tags", type=FieldTypeEnum.TAGS, is_default=True
)
COLLATION = DefaultField(id=9, name="Collation", type=FieldTypeEnum.TEXT_LINE)
DATE = DefaultField(id=10, name="Date", type=FieldTypeEnum.DATETIME)
DATE_CREATED = DefaultField(id=11, name="Date Created", type=FieldTypeEnum.DATETIME)
DATE_MODIFIED = DefaultField(
id=12, name="Date Modified", type=FieldTypeEnum.DATETIME
)
DATE_TAKEN = DefaultField(id=13, name="Date Taken", type=FieldTypeEnum.DATETIME)
DATE_PUBLISHED = DefaultField(
id=14, name="Date Published", type=FieldTypeEnum.DATETIME
)
# ARCHIVED = DefaultField(id=15, name="Archived", type=CheckboxField.checkbox)
# FAVORITE = DefaultField(id=16, name="Favorite", type=CheckboxField.checkbox)
BOOK = DefaultField(id=17, name="Book", type=FieldTypeEnum.TEXT_LINE)
COMIC = DefaultField(id=18, name="Comic", type=FieldTypeEnum.TEXT_LINE)
SERIES = DefaultField(id=19, name="Series", type=FieldTypeEnum.TEXT_LINE)
MANGA = DefaultField(id=20, name="Manga", type=FieldTypeEnum.TEXT_LINE)
SOURCE = DefaultField(id=21, name="Source", type=FieldTypeEnum.TEXT_LINE)
DATE_UPLOADED = DefaultField(
id=22, name="Date Uploaded", type=FieldTypeEnum.DATETIME
)
DATE_RELEASED = DefaultField(
id=23, name="Date Released", type=FieldTypeEnum.DATETIME
)
VOLUME = DefaultField(id=24, name="Volume", type=FieldTypeEnum.TEXT_LINE)
ANTHOLOGY = DefaultField(id=25, name="Anthology", type=FieldTypeEnum.TEXT_LINE)
MAGAZINE = DefaultField(id=26, name="Magazine", type=FieldTypeEnum.TEXT_LINE)
PUBLISHER = DefaultField(id=27, name="Publisher", type=FieldTypeEnum.TEXT_LINE)
GUEST_ARTIST = DefaultField(
id=28, name="Guest Artist", type=FieldTypeEnum.TEXT_LINE
)
COMPOSER = DefaultField(id=29, name="Composer", type=FieldTypeEnum.TEXT_LINE)
COMMENTS = DefaultField(id=30, name="Comments", type=FieldTypeEnum.TEXT_LINE)

View File

@@ -0,0 +1,20 @@
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from .db import Base
class TagSubtag(Base):
__tablename__ = "tag_subtags"
parent_id: Mapped[int] = mapped_column(ForeignKey("tags.id"), primary_key=True)
child_id: Mapped[int] = mapped_column(ForeignKey("tags.id"), primary_key=True)
class TagField(Base):
__tablename__ = "tag_fields"
field_id: Mapped[int] = mapped_column(
ForeignKey("tag_box_fields.id"), primary_key=True
)
tag_id: Mapped[int] = mapped_column(ForeignKey("tags.id"), primary_key=True)

View File

@@ -0,0 +1,884 @@
from datetime import datetime, UTC
import shutil
from os import makedirs
from pathlib import Path
from typing import Iterator, Any, Type
from uuid import uuid4
import structlog
from sqlalchemy import (
and_,
or_,
select,
create_engine,
Engine,
func,
update,
URL,
exists,
delete,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import (
Session,
contains_eager,
selectinload,
make_transient,
)
from typing import TYPE_CHECKING
from .db import make_tables
from .enums import TagColor, FilterState, FieldTypeEnum
from .fields import (
DatetimeField,
TagBoxField,
TextField,
_FieldID,
BaseField,
)
from .joins import TagSubtag, TagField
from .models import Entry, Preferences, Tag, TagAlias, ValueType, Folder
from ...constants import (
LibraryPrefs,
TS_FOLDER_NAME,
TAG_ARCHIVED,
TAG_FAVORITE,
BACKUP_FOLDER_NAME,
)
if TYPE_CHECKING:
from ...utils.dupe_files import DupeRegistry
from ...utils.missing_files import MissingRegistry
LIBRARY_FILENAME: str = "ts_library.sqlite"
logger = structlog.get_logger(__name__)
import re
import unicodedata
def slugify(input_string: str) -> str:
# Convert to lowercase and normalize unicode characters
slug = unicodedata.normalize("NFKD", input_string.lower())
# Remove non-word characters (except hyphens and spaces)
slug = re.sub(r"[^\w\s-]", "", slug).strip()
# Replace spaces with hyphens
slug = re.sub(r"[-\s]+", "-", slug)
return slug
def get_default_tags() -> tuple[Tag, ...]:
archive_tag = Tag(
id=TAG_ARCHIVED,
name="Archived",
aliases={TagAlias(name="Archive")},
color=TagColor.RED,
)
favorite_tag = Tag(
id=TAG_FAVORITE,
name="Favorite",
aliases={
TagAlias(name="Favorited"),
TagAlias(name="Favorites"),
},
color=TagColor.YELLOW,
)
return archive_tag, favorite_tag
class Library:
"""Class for the Library object, and all CRUD operations made upon it."""
library_dir: Path
storage_path: Path | str
engine: Engine | None
folder: Folder | None
ignored_extensions: list[str]
missing_tracker: "MissingRegistry"
dupe_tracker: "DupeRegistry"
def open_library(
self, library_dir: Path | str, storage_path: str | None = None
) -> None:
if isinstance(library_dir, str):
library_dir = Path(library_dir)
self.library_dir = library_dir
if storage_path == ":memory:":
self.storage_path = storage_path
else:
self.verify_ts_folders(self.library_dir)
self.storage_path = self.library_dir / TS_FOLDER_NAME / LIBRARY_FILENAME
connection_string = URL.create(
drivername="sqlite",
database=str(self.storage_path),
)
logger.info("opening library", connection_string=connection_string)
self.engine = create_engine(connection_string)
with Session(self.engine) as session:
make_tables(self.engine)
tags = get_default_tags()
try:
session.add_all(tags)
session.commit()
except IntegrityError:
# default tags may exist already
session.rollback()
for pref in LibraryPrefs:
try:
session.add(Preferences(key=pref.name, value=pref.value))
session.commit()
except IntegrityError:
logger.debug("preference already exists", pref=pref)
session.rollback()
for field in _FieldID:
try:
session.add(
ValueType(
key=field.name,
name=field.value.name,
type=field.value.type,
position=field.value.id,
is_default=field.value.is_default,
)
)
session.commit()
except IntegrityError:
logger.debug("ValueType already exists", field=field)
session.rollback()
# check if folder matching current path exists already
self.folder = session.scalar(
select(Folder).where(Folder.path == self.library_dir)
)
if not self.folder:
folder = Folder(
path=self.library_dir,
uuid=str(uuid4()),
)
session.add(folder)
session.expunge(folder)
session.commit()
self.folder = folder
# load ignored extensions
self.ignored_extensions = self.prefs(LibraryPrefs.EXTENSION_LIST)
@property
def default_fields(self) -> list[BaseField]:
with Session(self.engine) as session:
types = session.scalars(
select(ValueType).where(
# check if field is default
ValueType.is_default.is_(True)
)
)
return [x.as_field for x in types]
def delete_item(self, item):
logger.info("deleting item", item=item)
with Session(self.engine) as session:
session.delete(item)
session.commit()
def remove_field_tag(self, entry: Entry, tag_id: int, field_key: str) -> bool:
assert isinstance(field_key, str), f"field_key is {type(field_key)}"
with Session(self.engine) as session:
# find field matching entry and field_type
field = session.scalars(
select(TagBoxField).where(
and_(
TagBoxField.entry_id == entry.id,
TagBoxField.type_key == field_key,
)
)
).first()
if not field:
logger.error("no field found", entry=entry, field=field)
return False
try:
# find the record in `TagField` table and delete it
tag_field = session.scalars(
select(TagField).where(
and_(
TagField.tag_id == tag_id,
TagField.field_id == field.id,
)
)
).first()
if tag_field:
session.delete(tag_field)
session.commit()
return True
except IntegrityError as e:
logger.exception(e)
session.rollback()
return False
def get_entry(self, entry_id: int) -> Entry | None:
"""Load entry without joins."""
with Session(self.engine) as session:
entry = session.scalar(select(Entry).where(Entry.id == entry_id))
if not entry:
return None
session.expunge(entry)
make_transient(entry)
return entry
@property
def entries_count(self) -> int:
with Session(self.engine) as session:
return session.scalar(select(func.count(Entry.id)))
def get_entries(self, with_joins: bool = False) -> Iterator[Entry]:
"""Load entries without joins."""
with Session(self.engine) as session:
stmt = select(Entry)
if with_joins:
# load Entry with all joins and all tags
stmt = (
stmt.outerjoin(Entry.text_fields)
.outerjoin(Entry.datetime_fields)
.outerjoin(Entry.tag_box_fields)
)
stmt = stmt.options(
contains_eager(Entry.text_fields),
contains_eager(Entry.datetime_fields),
contains_eager(Entry.tag_box_fields).selectinload(TagBoxField.tags),
)
stmt = stmt.distinct()
entries = session.execute(stmt).scalars()
if with_joins:
entries = entries.unique()
for entry in entries:
yield entry
session.expunge(entry)
@property
def tags(self) -> list[Tag]:
with Session(self.engine) as session:
# load all tags and join subtags
tags_query = select(Tag).options(selectinload(Tag.subtags))
tags = session.scalars(tags_query).unique()
tags_list = list(tags)
for tag in tags_list:
session.expunge(tag)
return list(tags_list)
def verify_ts_folders(self, library_dir: Path) -> None:
"""Verify/create folders required by TagStudio."""
if library_dir is None:
raise ValueError("No path set.")
if not library_dir.exists():
raise ValueError("Invalid library directory.")
full_ts_path = library_dir / TS_FOLDER_NAME
if not full_ts_path.exists():
logger.info("creating library directory", dir=full_ts_path)
full_ts_path.mkdir(parents=True, exist_ok=True)
def add_entries(self, items: list[Entry]) -> list[int]:
"""Add multiple Entry records to the Library."""
assert items
with Session(self.engine) as session:
# add all items
session.add_all(items)
session.flush()
new_ids = [item.id for item in items]
session.expunge_all()
session.commit()
return new_ids
def remove_entries(self, entry_ids: list[int]) -> None:
"""Remove Entry items matching supplied IDs from the Library."""
with Session(self.engine) as session:
session.query(Entry).where(Entry.id.in_(entry_ids)).delete()
session.commit()
def has_path_entry(self, path: Path) -> bool:
"""Check if item with given path is in library already."""
with Session(self.engine) as session:
return session.query(exists().where(Entry.path == path)).scalar()
def search_library(
self,
search: FilterState,
) -> tuple[int, list[Entry]]:
"""Filter library by search query.
:return: number of entries matching the query and one page of results.
"""
assert isinstance(search, FilterState)
assert self.engine
with Session(self.engine, expire_on_commit=False) as session:
statement = select(Entry)
if search.tag:
statement = (
statement.join(Entry.tag_box_fields)
.join(TagBoxField.tags)
.where(
or_(
Tag.name.ilike(search.tag),
Tag.shorthand.ilike(search.tag),
)
)
)
elif search.tag_id:
statement = (
statement.join(Entry.tag_box_fields)
.join(TagBoxField.tags)
.where(Tag.id == search.tag_id)
)
elif search.id:
statement = statement.where(Entry.id == search.id)
elif search.name:
statement = select(Entry).where(
and_(
Entry.path.ilike(f"%{search.name}%"),
# dont match directory name (ie. has following slash)
~Entry.path.ilike(f"%{search.name}%/%"),
)
)
elif search.path:
statement = statement.where(Entry.path.ilike(f"%{search.path}%"))
extensions = self.prefs(LibraryPrefs.EXTENSION_LIST)
is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST)
if not search.id: # if `id` is set, we don't need to filter by extensions
if extensions and is_exclude_list:
statement = statement.where(
Entry.path.notilike(f"%.{','.join(extensions)}")
)
elif extensions:
statement = statement.where(
Entry.path.ilike(f"%.{','.join(extensions)}")
)
statement = statement.options(
selectinload(Entry.text_fields),
selectinload(Entry.datetime_fields),
selectinload(Entry.tag_box_fields)
.joinedload(TagBoxField.tags)
.options(selectinload(Tag.aliases), selectinload(Tag.subtags)),
)
query_count = select(func.count()).select_from(statement.alias("entries"))
count_all: int = session.execute(query_count).scalar()
statement = statement.limit(search.limit).offset(search.offset)
logger.info(
"searching library",
filter=search,
query_full=str(
statement.compile(compile_kwargs={"literal_binds": True})
),
)
entries_ = list(session.scalars(statement).unique())
session.expunge_all()
return count_all, entries_
def search_tags(
self,
search: FilterState,
) -> list[Tag]:
"""Return a list of Tag records matching the query."""
with Session(self.engine) as session:
query = select(Tag)
query = query.options(
selectinload(Tag.subtags),
selectinload(Tag.aliases),
)
if search.tag:
query = query.where(
or_(
Tag.name.ilike(search.tag),
Tag.shorthand.ilike(search.tag),
)
)
tags = session.scalars(query)
res = list(tags)
logger.info(
"searching tags",
search=search,
statement=str(query),
results=len(res),
)
session.expunge_all()
return res
def get_all_child_tag_ids(self, tag_id: int) -> list[int]:
"""Recursively traverse a Tag's subtags and return a list of all children tags."""
all_subtags: set[int] = {tag_id}
with Session(self.engine) as session:
tag = session.scalar(select(Tag).where(Tag.id == tag_id))
if tag is None:
raise ValueError(f"No tag found with id {tag_id}.")
subtag_ids = tag.subtag_ids
all_subtags.update(subtag_ids)
for sub_id in subtag_ids:
all_subtags.update(self.get_all_child_tag_ids(sub_id))
return list(all_subtags)
def update_entry_path(self, entry_id: int | Entry, path: Path) -> None:
if isinstance(entry_id, Entry):
entry_id = entry_id.id
with Session(self.engine) as session:
update_stmt = (
update(Entry)
.where(
and_(
Entry.id == entry_id,
)
)
.values(path=path)
)
session.execute(update_stmt)
session.commit()
def remove_tag_from_field(self, tag: Tag, field: TagBoxField) -> None:
with Session(self.engine) as session:
field_ = session.scalars(
select(TagBoxField).where(TagBoxField.id == field.id)
).one()
tag = session.scalars(select(Tag).where(Tag.id == tag.id)).one()
field_.tags.remove(tag)
session.add(field_)
session.commit()
def update_field_position(
self,
field_class: Type[BaseField],
field_type: str,
entry_ids: list[int] | int,
):
if isinstance(entry_ids, int):
entry_ids = [entry_ids]
with Session(self.engine) as session:
for entry_id in entry_ids:
rows = list(
session.scalars(
select(field_class)
.where(
and_(
field_class.entry_id == entry_id,
field_class.type_key == field_type,
)
)
.order_by(field_class.id)
)
)
# Reassign `order` starting from 0
for index, row in enumerate(rows):
row.position = index
session.add(row)
session.flush()
if rows:
session.commit()
def remove_entry_field(
self,
field: BaseField,
entry_ids: list[int],
) -> None:
FieldClass = type(field)
logger.info(
"remove_entry_field",
field=field,
entry_ids=entry_ids,
field_type=field.type,
cls=FieldClass,
pos=field.position,
)
with Session(self.engine) as session:
# remove all fields matching entry and field_type
delete_stmt = delete(FieldClass).where(
and_(
FieldClass.position == field.position,
FieldClass.type_key == field.type_key,
FieldClass.entry_id.in_(entry_ids),
)
)
session.execute(delete_stmt)
session.commit()
# recalculate the remaining positions
# self.update_field_position(type(field), field.type, entry_ids)
def update_entry_field(
self,
entry_ids: list[int] | int,
field: BaseField,
content: str | datetime | set[Tag],
):
if isinstance(entry_ids, int):
entry_ids = [entry_ids]
FieldClass = type(field)
with Session(self.engine) as session:
update_stmt = (
update(FieldClass)
.where(
and_(
FieldClass.position == field.position,
FieldClass.type == field.type,
FieldClass.entry_id.in_(entry_ids),
)
)
.values(value=content)
)
session.execute(update_stmt)
session.commit()
@property
def field_types(self) -> dict[str, ValueType]:
with Session(self.engine) as session:
return {x.key: x for x in session.scalars(select(ValueType)).all()}
def get_value_type(self, field_key: str) -> ValueType:
with Session(self.engine) as session:
field = session.scalar(select(ValueType).where(ValueType.key == field_key))
session.expunge(field)
return field
def add_entry_field_type(
self,
entry_ids: list[int] | int,
*,
field: ValueType | None = None,
field_id: _FieldID | str | None = None,
value: str | datetime | list[str] | None = None,
) -> bool:
logger.info(
"add_field_to_entry",
entry_ids=entry_ids,
field_type=field,
field_id=field_id,
value=value,
)
# supply only instance or ID, not both
assert bool(field) != (field_id is not None)
if isinstance(entry_ids, int):
entry_ids = [entry_ids]
if not field:
if isinstance(field_id, _FieldID):
field_id = field_id.name
field = self.get_value_type(field_id)
field_model: TextField | DatetimeField | TagBoxField
if field.type in (FieldTypeEnum.TEXT_LINE, FieldTypeEnum.TEXT_BOX):
field_model = TextField(
type_key=field.key,
value=value or "",
)
elif field.type == FieldTypeEnum.TAGS:
field_model = TagBoxField(
type_key=field.key,
)
if value:
assert isinstance(value, list)
for tag in value:
field_model.tags.add(Tag(name=tag))
elif field.type == FieldTypeEnum.DATETIME:
field_model = DatetimeField(
type_key=field.key,
value=value,
)
else:
raise NotImplementedError(f"field type not implemented: {field.type}")
with Session(self.engine) as session:
try:
for entry_id in entry_ids:
field_model.entry_id = entry_id
session.add(field_model)
session.flush()
session.commit()
except IntegrityError as e:
logger.exception(e)
session.rollback()
return False
# TODO - trigger error signal
# recalculate the positions of fields
self.update_field_position(
field_class=type(field_model),
field_type=field.key,
entry_ids=entry_ids,
)
return True
def add_tag(self, tag: Tag, subtag_ids: list[int] | None = None) -> Tag | None:
with Session(self.engine, expire_on_commit=False) as session:
try:
session.add(tag)
session.flush()
for subtag_id in subtag_ids or []:
subtag = TagSubtag(
parent_id=tag.id,
child_id=subtag_id,
)
session.add(subtag)
session.commit()
session.expunge(tag)
return tag
except IntegrityError as e:
logger.exception(e)
session.rollback()
return None
def add_field_tag(
self,
entry: Entry,
tag: Tag,
field_key: str = _FieldID.TAGS.name,
create_field: bool = False,
) -> bool:
assert isinstance(field_key, str), f"field_key is {type(field_key)}"
with Session(self.engine) as session:
# find field matching entry and field_type
field = session.scalars(
select(TagBoxField).where(
and_(
TagBoxField.entry_id == entry.id,
TagBoxField.type_key == field_key,
)
)
).first()
if not field and not create_field:
logger.error("no field found", entry=entry, field_key=field_key)
return False
try:
if not field:
field = TagBoxField(
type_key=field_key,
entry_id=entry.id,
position=0,
)
session.add(field)
session.flush()
# create record for `TagField` table
if not tag.id:
session.add(tag)
session.flush()
tag_field = TagField(
tag_id=tag.id,
field_id=field.id,
)
session.add(tag_field)
session.commit()
logger.info(
"tag added to field", tag=tag, field=field, entry_id=entry.id
)
return True
except IntegrityError as e:
logger.exception(e)
session.rollback()
return False
def save_library_backup_to_disk(self) -> Path:
assert isinstance(self.library_dir, Path)
makedirs(
str(self.library_dir / TS_FOLDER_NAME / BACKUP_FOLDER_NAME), exist_ok=True
)
filename = (
f'ts_library_backup_{datetime.now(UTC).strftime("%Y_%m_%d_%H%M%S")}.sqlite'
)
target_path = self.library_dir / TS_FOLDER_NAME / BACKUP_FOLDER_NAME / filename
shutil.copy2(
self.library_dir / TS_FOLDER_NAME / LIBRARY_FILENAME,
target_path,
)
return target_path
def get_tag(self, tag_id: int) -> Tag:
with Session(self.engine) as session:
tags_query = select(Tag).options(selectinload(Tag.subtags))
tag = session.scalar(tags_query.where(Tag.id == tag_id))
session.expunge(tag)
for subtag in tag.subtags:
session.expunge(subtag)
return tag
def add_subtag(self, base_id: int, new_tag_id: int) -> bool:
# open session and save as parent tag
with Session(self.engine) as session:
tag = TagSubtag(
parent_id=base_id,
child_id=new_tag_id,
)
try:
session.add(tag)
session.commit()
return True
except IntegrityError:
session.rollback()
logger.exception("IntegrityError")
return False
def update_tag(self, tag: Tag, subtag_ids: list[int]) -> None:
"""
Edit a Tag in the Library.
"""
# TODO - maybe merge this with add_tag?
if tag.shorthand:
tag.shorthand = slugify(tag.shorthand)
if tag.aliases:
# TODO
...
# save the tag
with Session(self.engine) as session:
try:
# update the existing tag
session.add(tag)
session.flush()
# load all tag's subtag to know which to remove
prev_subtags = session.scalars(
select(TagSubtag).where(TagSubtag.parent_id == tag.id)
).all()
for subtag in prev_subtags:
if subtag.child_id not in subtag_ids:
session.delete(subtag)
else:
# no change, remove from list
subtag_ids.remove(subtag.child_id)
# create remaining items
for subtag_id in subtag_ids:
# add new subtag
subtag = TagSubtag(
parent_id=tag.id,
child_id=subtag_id,
)
session.add(subtag)
session.commit()
except IntegrityError:
session.rollback()
logger.exception("IntegrityError")
def prefs(self, key: LibraryPrefs) -> Any:
# load given item from Preferences table
with Session(self.engine) as session:
return session.scalar(
select(Preferences).where(Preferences.key == key.name)
).value
def set_prefs(self, key: LibraryPrefs, value: Any) -> None:
# set given item in Preferences table
with Session(self.engine) as session:
# load existing preference and update value
pref = session.scalar(
select(Preferences).where(Preferences.key == key.name)
)
pref.value = value
session.add(pref)
session.commit()
# TODO - try/except
def mirror_entry_fields(self, *entries: Entry) -> None:
"""Mirror fields among multiple Entry items."""
fields = {}
# load all fields
existing_fields = {field.type_key for field in entries[0].fields}
for entry in entries:
for entry_field in entry.fields:
fields[entry_field.type_key] = entry_field
# assign the field to all entries
for entry in entries:
for field_key, field in fields.items():
if field_key not in existing_fields:
self.add_entry_field_type(
entry_ids=entry.id,
field_id=field.type_key,
value=field.value,
)

View File

@@ -0,0 +1,270 @@
from pathlib import Path
from typing import Optional
from sqlalchemy import JSON, ForeignKey, Integer, event
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .db import Base, PathType
from .enums import TagColor
from .fields import (
DatetimeField,
TagBoxField,
TextField,
FieldTypeEnum,
_FieldID,
BaseField,
BooleanField,
)
from .joins import TagSubtag
from ...constants import TAG_FAVORITE, TAG_ARCHIVED
class TagAlias(Base):
__tablename__ = "tag_aliases"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
tag_id: Mapped[int] = mapped_column(ForeignKey("tags.id"))
tag: Mapped["Tag"] = relationship(back_populates="aliases")
def __init__(self, name: str, tag: Optional["Tag"] = None):
self.name = name
if tag:
self.tag = tag
super().__init__()
class Tag(Base):
__tablename__ = "tags"
__table_args__ = {"sqlite_autoincrement": True}
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(unique=True)
shorthand: Mapped[str | None]
color: Mapped[TagColor]
icon: Mapped[str | None]
aliases: Mapped[set[TagAlias]] = relationship(back_populates="tag")
parent_tags: Mapped[set["Tag"]] = relationship(
secondary=TagSubtag.__tablename__,
primaryjoin="Tag.id == TagSubtag.child_id",
secondaryjoin="Tag.id == TagSubtag.parent_id",
back_populates="subtags",
)
subtags: Mapped[set["Tag"]] = relationship(
secondary=TagSubtag.__tablename__,
primaryjoin="Tag.id == TagSubtag.parent_id",
secondaryjoin="Tag.id == TagSubtag.child_id",
back_populates="parent_tags",
)
@property
def subtag_ids(self) -> list[int]:
return [tag.id for tag in self.subtags]
@property
def alias_strings(self) -> list[str]:
return [alias.name for alias in self.aliases]
def __init__(
self,
name: str,
shorthand: str | None = None,
aliases: set[TagAlias] | None = None,
parent_tags: set["Tag"] | None = None,
subtags: set["Tag"] | None = None,
icon: str | None = None,
color: TagColor = TagColor.DEFAULT,
id: int | None = None,
):
self.name = name
self.aliases = aliases or set()
self.parent_tags = parent_tags or set()
self.subtags = subtags or set()
self.color = color
self.icon = icon
self.shorthand = shorthand
assert not self.id
self.id = id
super().__init__()
def __str__(self) -> str:
return f"<Tag ID: {self.id} Name: {self.name}>"
def __repr__(self) -> str:
return self.__str__()
class Folder(Base):
__tablename__ = "folders"
# TODO - implement this
id: Mapped[int] = mapped_column(primary_key=True)
path: Mapped[Path] = mapped_column(PathType, unique=True)
uuid: Mapped[str] = mapped_column(unique=True)
class Entry(Base):
__tablename__ = "entries"
id: Mapped[int] = mapped_column(primary_key=True)
folder_id: Mapped[int] = mapped_column(ForeignKey("folders.id"))
folder: Mapped[Folder] = relationship("Folder")
path: Mapped[Path] = mapped_column(PathType, unique=True)
text_fields: Mapped[list[TextField]] = relationship(
back_populates="entry",
cascade="all, delete",
)
datetime_fields: Mapped[list[DatetimeField]] = relationship(
back_populates="entry",
cascade="all, delete",
)
tag_box_fields: Mapped[list[TagBoxField]] = relationship(
back_populates="entry",
cascade="all, delete",
)
@property
def fields(self) -> list[BaseField]:
fields: list[BaseField] = []
fields.extend(self.tag_box_fields)
fields.extend(self.text_fields)
fields.extend(self.datetime_fields)
fields = sorted(fields, key=lambda field: field.type.position)
return fields
@property
def tags(self) -> set[Tag]:
tag_set: set[Tag] = set()
for tag_box_field in self.tag_box_fields:
tag_set.update(tag_box_field.tags)
return tag_set
@property
def is_favorited(self) -> bool:
for tag_box_field in self.tag_box_fields:
if tag_box_field.type_key == _FieldID.TAGS_META.name:
for tag in tag_box_field.tags:
if tag.id == TAG_FAVORITE:
return True
return False
@property
def is_archived(self) -> bool:
for tag_box_field in self.tag_box_fields:
if tag_box_field.type_key == _FieldID.TAGS_META.name:
for tag in tag_box_field.tags:
if tag.id == TAG_ARCHIVED:
return True
return False
def __init__(
self,
path: Path,
folder: Folder,
fields: list[BaseField],
) -> None:
self.path = path
self.folder = folder
for field in fields:
if isinstance(field, TextField):
self.text_fields.append(field)
elif isinstance(field, DatetimeField):
self.datetime_fields.append(field)
elif isinstance(field, TagBoxField):
self.tag_box_fields.append(field)
else:
raise ValueError(f"Invalid field type: {field}")
def has_tag(self, tag: Tag) -> bool:
return tag in self.tags
def remove_tag(self, tag: Tag, field: TagBoxField | None = None) -> None:
"""
Removes a Tag from the Entry. If given a field index, the given Tag will
only be removed from that index. If left blank, all instances of that
Tag will be removed from the Entry.
"""
if field:
field.tags.remove(tag)
return
for tag_box_field in self.tag_box_fields:
tag_box_field.tags.remove(tag)
class ValueType(Base):
"""Define Field Types in the Library.
Example:
key: content_tags (this field is slugified `name`)
name: Content Tags (this field is human readable name)
kind: type of content (Text Line, Text Box, Tags, Datetime, Checkbox)
is_default: Should the field be present in new Entry?
order: position of the field widget in the Entry form
"""
__tablename__ = "value_type"
key: Mapped[str] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(nullable=False)
type: Mapped[FieldTypeEnum] = mapped_column(default=FieldTypeEnum.TEXT_LINE)
is_default: Mapped[bool]
position: Mapped[int]
# add relations to other tables
text_fields: Mapped[list[TextField]] = relationship(
"TextField", back_populates="type"
)
datetime_fields: Mapped[list[DatetimeField]] = relationship(
"DatetimeField", back_populates="type"
)
tag_box_fields: Mapped[list[TagBoxField]] = relationship(
"TagBoxField", back_populates="type"
)
boolean_fields: Mapped[list[BooleanField]] = relationship(
"BooleanField", back_populates="type"
)
@property
def as_field(self) -> BaseField:
FieldClass = {
FieldTypeEnum.TEXT_LINE: TextField,
FieldTypeEnum.TEXT_BOX: TextField,
FieldTypeEnum.TAGS: TagBoxField,
FieldTypeEnum.DATETIME: DatetimeField,
FieldTypeEnum.BOOLEAN: BooleanField,
}
return FieldClass[self.type](
type_key=self.key,
position=self.position,
)
@event.listens_for(ValueType, "before_insert")
def slugify_field_key(mapper, connection, target):
"""Slugify the field key before inserting into the database."""
if not target.key:
from .library import slugify
target.key = slugify(target.tag)
class Preferences(Base):
__tablename__ = "preferences"
key: Mapped[str] = mapped_column(primary_key=True)
value: Mapped[dict] = mapped_column(JSON, nullable=False)

View File

@@ -0,0 +1,38 @@
BOX_FIELDS = ["tag_box", "text_box"]
TEXT_FIELDS = ["text_line", "text_box"]
DATE_FIELDS = ["datetime"]
DEFAULT_FIELDS: list[dict] = [
{"id": 0, "name": "Title", "type": "text_line"},
{"id": 1, "name": "Author", "type": "text_line"},
{"id": 2, "name": "Artist", "type": "text_line"},
{"id": 3, "name": "URL", "type": "text_line"},
{"id": 4, "name": "Description", "type": "text_box"},
{"id": 5, "name": "Notes", "type": "text_box"},
{"id": 6, "name": "Tags", "type": "tag_box"},
{"id": 7, "name": "Content Tags", "type": "tag_box"},
{"id": 8, "name": "Meta Tags", "type": "tag_box"},
{"id": 9, "name": "Collation", "type": "collation"},
{"id": 10, "name": "Date", "type": "datetime"},
{"id": 11, "name": "Date Created", "type": "datetime"},
{"id": 12, "name": "Date Modified", "type": "datetime"},
{"id": 13, "name": "Date Taken", "type": "datetime"},
{"id": 14, "name": "Date Published", "type": "datetime"},
{"id": 15, "name": "Archived", "type": "checkbox"},
{"id": 16, "name": "Favorite", "type": "checkbox"},
{"id": 17, "name": "Book", "type": "collation"},
{"id": 18, "name": "Comic", "type": "collation"},
{"id": 19, "name": "Series", "type": "collation"},
{"id": 20, "name": "Manga", "type": "collation"},
{"id": 21, "name": "Source", "type": "text_line"},
{"id": 22, "name": "Date Uploaded", "type": "datetime"},
{"id": 23, "name": "Date Released", "type": "datetime"},
{"id": 24, "name": "Volume", "type": "collation"},
{"id": 25, "name": "Anthology", "type": "collation"},
{"id": 26, "name": "Magazine", "type": "collation"},
{"id": 27, "name": "Publisher", "type": "text_line"},
{"id": 28, "name": "Guest Artist", "type": "text_line"},
{"id": 29, "name": "Composer", "type": "text_line"},
{"id": 30, "name": "Comments", "type": "text_box"},
]

View File

@@ -1,3 +1,5 @@
# type: ignore
# ruff: noqa
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
@@ -5,11 +7,12 @@
"""The Library object and related methods for TagStudio."""
import datetime
import logging
import os
import time
import traceback
import xml.etree.ElementTree as ET
import structlog
import ujson
from enum import Enum
@@ -17,15 +20,13 @@ from pathlib import Path
from typing import cast, Generator
from typing_extensions import Self
from src.core.enums import FieldID
from src.core.json_typing import JsonCollation, JsonEntry, JsonLibary, JsonTag
from .fields import DEFAULT_FIELDS, TEXT_FIELDS
from src.core.enums import OpenStatus
from src.core.utils.str import strip_punctuation
from src.core.utils.web import strip_web_protocol
from src.core.enums import SearchMode
from src.core.constants import (
BACKUP_FOLDER_NAME,
COLLAGE_FOLDER_NAME,
TEXT_FIELDS,
TS_FOLDER_NAME,
VERSION,
)
@@ -40,7 +41,7 @@ class ItemType(Enum):
TAG_GROUP = 2
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = structlog.get_logger(__name__)
class Entry:
@@ -94,12 +95,12 @@ class Entry:
and self.fields == __value.fields
)
def compressed_dict(self) -> JsonEntry:
def compressed_dict(self):
"""
An alternative to __dict__ that only includes fields containing
non-default data.
"""
obj: JsonEntry = {"id": self.id}
obj = {"id": self.id}
if self.filename:
obj["filename"] = str(self.filename)
if self.path:
@@ -128,7 +129,7 @@ class Entry:
if library.get_field_attr(f, "type") == "tag_box":
if field_index >= 0 and field_index == i:
t: list[int] = library.get_field_attr(f, "content")
logging.info(
logger.info(
f't:{tag_id}, i:{i}, idx:{field_index}, c:{library.get_field_attr(f, "content")}'
)
t.remove(tag_id)
@@ -142,30 +143,30 @@ class Entry:
):
# if self.fields:
# if field_index != -1:
# logging.info(f'[LIBRARY] ADD TAG to E:{self.id}, F-DI:{field_id}, F-INDEX:{field_index}')
# logger.info(f'[LIBRARY] ADD TAG to E:{self.id}, F-DI:{field_id}, F-INDEX:{field_index}')
for i, f in enumerate(self.fields):
if library.get_field_attr(f, "id") == field_id:
field_index = i
# logging.info(f'[LIBRARY] FOUND F-INDEX:{field_index}')
# logger.info(f'[LIBRARY] FOUND F-INDEX:{field_index}')
break
if field_index == -1:
library.add_field_to_entry(self.id, field_id)
# logging.info(f'[LIBRARY] USING NEWEST F-INDEX:{field_index}')
# logger.info(f'[LIBRARY] USING NEWEST F-INDEX:{field_index}')
# logging.info(list(self.fields[field_index].keys()))
# logger.info(list(self.fields[field_index].keys()))
field_id = list(self.fields[field_index].keys())[0]
# logging.info(f'Entry Field ID: {field_id}, Index: {field_index}')
# logger.info(f'Entry Field ID: {field_id}, Index: {field_index}')
tags: list[int] = self.fields[field_index][field_id]
if tag_id not in tags:
# logging.info(f'Adding Tag: {tag_id}')
# logger.info(f'Adding Tag: {tag_id}')
tags.append(tag_id)
self.fields[field_index][field_id] = sorted(
tags, key=lambda t: library.get_tag(t).display_name(library)
)
# logging.info(f'Tags: {self.fields[field_index][field_id]}')
# logger.info(f'Tags: {self.fields[field_index][field_id]}')
class Tag:
@@ -219,12 +220,12 @@ class Tag:
else:
return f"{self.name}"
def compressed_dict(self) -> JsonTag:
def compressed_dict(self):
"""
An alternative to __dict__ that only includes fields containing
non-default data.
"""
obj: JsonTag = {"id": self.id}
obj = {"id": self.id}
if self.name:
obj["name"] = self.name
if self.shorthand:
@@ -281,12 +282,12 @@ class Collation:
__value = cast(Self, __value)
return int(self.id) == int(__value.id) and self.fields == __value.fields
def compressed_dict(self) -> JsonCollation:
def compressed_dict(self):
"""
An alternative to __dict__ that only includes fields containing
non-default data.
"""
obj: JsonCollation = {"id": self.id}
obj = {"id": self.id}
if self.title:
obj["title"] = self.title
if self.e_ids_and_pages:
@@ -368,7 +369,7 @@ class Library:
# Map of every Tag ID to the index of the Tag in self.tags.
self._tag_id_to_index_map: dict[int, int] = {}
self.default_tags: list[JsonTag] = [
self.default_tags: list = [
{"id": 0, "name": "Archived", "aliases": ["Archive"], "color": "Red"},
{
"id": 1,
@@ -383,40 +384,6 @@ class Library:
# Tag(id=1, name='Favorite', shorthand='', aliases=['Favorited, Favorites, Likes, Liked, Loved'], subtags_ids=[], color='yellow'),
# ]
self.default_fields: list[dict] = [
{"id": 0, "name": "Title", "type": "text_line"},
{"id": 1, "name": "Author", "type": "text_line"},
{"id": 2, "name": "Artist", "type": "text_line"},
{"id": 3, "name": "URL", "type": "text_line"},
{"id": 4, "name": "Description", "type": "text_box"},
{"id": 5, "name": "Notes", "type": "text_box"},
{"id": 6, "name": "Tags", "type": "tag_box"},
{"id": 7, "name": "Content Tags", "type": "tag_box"},
{"id": 8, "name": "Meta Tags", "type": "tag_box"},
{"id": 9, "name": "Collation", "type": "collation"},
{"id": 10, "name": "Date", "type": "datetime"},
{"id": 11, "name": "Date Created", "type": "datetime"},
{"id": 12, "name": "Date Modified", "type": "datetime"},
{"id": 13, "name": "Date Taken", "type": "datetime"},
{"id": 14, "name": "Date Published", "type": "datetime"},
{"id": 15, "name": "Archived", "type": "checkbox"},
{"id": 16, "name": "Favorite", "type": "checkbox"},
{"id": 17, "name": "Book", "type": "collation"},
{"id": 18, "name": "Comic", "type": "collation"},
{"id": 19, "name": "Series", "type": "collation"},
{"id": 20, "name": "Manga", "type": "collation"},
{"id": 21, "name": "Source", "type": "text_line"},
{"id": 22, "name": "Date Uploaded", "type": "datetime"},
{"id": 23, "name": "Date Released", "type": "datetime"},
{"id": 24, "name": "Volume", "type": "collation"},
{"id": 25, "name": "Anthology", "type": "collation"},
{"id": 26, "name": "Magazine", "type": "collation"},
{"id": 27, "name": "Publisher", "type": "text_line"},
{"id": 28, "name": "Guest Artist", "type": "text_line"},
{"id": 29, "name": "Composer", "type": "text_line"},
{"id": 30, "name": "Comments", "type": "text_box"},
]
def create_library(self, path: Path) -> int:
"""
Creates a TagStudio library in the given directory.\n
@@ -433,7 +400,7 @@ class Library:
self.verify_ts_folders()
self.save_library_to_disk()
self.open_library(self.library_dir)
except:
except Exception:
traceback.print_exc()
return 2
@@ -443,7 +410,7 @@ class Library:
"""If '.TagStudio' is included in the path, trim the path up to it."""
path = Path(path)
paths = [x for x in [path, *path.parents] if x.stem == TS_FOLDER_NAME]
if len(paths) > 0:
if paths:
return paths[0].parent
return path
@@ -463,12 +430,12 @@ class Library:
if not os.path.isdir(full_collage_path):
os.mkdir(full_collage_path)
def verify_default_tags(self, tag_list: list[JsonTag]) -> list[JsonTag]:
def verify_default_tags(self, tag_list: list) -> list:
"""
Ensures that the default builtin tags are present in the Library's
save file. Takes in and returns the tag dictionary from the JSON file.
"""
missing: list[JsonTag] = []
missing: list = []
for dt in self.default_tags:
if dt["id"] not in [t["id"] for t in tag_list]:
@@ -479,16 +446,14 @@ class Library:
return tag_list
def open_library(self, path: str | Path) -> int:
def open_library(self, path: str | Path) -> OpenStatus:
"""
Opens a TagStudio v9+ Library.
Returns 0 if library does not exist, 1 if successfully opened, 2 if corrupted.
Open a TagStudio v9+ Library.
"""
return_code: int = 2
return_code = OpenStatus.CORRUPTED
_path: Path = self._fix_lib_path(path)
logger.info("opening library", path=_path)
if (_path / TS_FOLDER_NAME / "ts_library.json").exists():
try:
with open(
@@ -496,7 +461,7 @@ class Library:
"r",
encoding="utf-8",
) as file:
json_dump: JsonLibary = ujson.load(file)
json_dump = ujson.load(file)
self.library_dir = Path(_path)
self.verify_ts_folders()
major, minor, patch = json_dump["ts-version"].split(".")
@@ -525,7 +490,7 @@ class Library:
self.is_exclude_list = json_dump.get("is_exclude_list", True)
end_time = time.time()
logging.info(
logger.info(
f"[LIBRARY] Extension list loaded in {(end_time - start_time):.3f} seconds"
)
@@ -570,7 +535,7 @@ class Library:
self._map_tag_id_to_index(t, -1)
self._map_tag_strings_to_tag_id(t)
else:
logging.info(
logger.info(
f"[LIBRARY]Skipping Tag with duplicate ID: {tag}"
)
@@ -579,7 +544,7 @@ class Library:
self._map_tag_id_to_cluster(t)
end_time = time.time()
logging.info(
logger.info(
f"[LIBRARY] Tags loaded in {(end_time - start_time):.3f} seconds"
)
@@ -680,8 +645,8 @@ class Library:
self._map_entry_id_to_index(e, -1)
end_time = time.time()
logging.info(
f"[LIBRARY] Entries loaded in {(end_time - start_time):.3f} seconds"
logger.info(
f"[LIBRARY] Entries loaded", load_time=end_time - start_time
)
# Parse Collations -----------------------------------------
@@ -704,7 +669,7 @@ class Library:
c = Collation(
id=id,
title=title,
e_ids_and_pages=e_ids_and_pages, # type: ignore
e_ids_and_pages=e_ids_and_pages,
sort_order=sort_order,
cover_id=cover_id,
)
@@ -716,16 +681,16 @@ class Library:
self.collations.append(c)
self._map_collation_id_to_index(c, -1)
end_time = time.time()
logging.info(
logger.info(
f"[LIBRARY] Collations loaded in {(end_time - start_time):.3f} seconds"
)
return_code = 1
return_code = OpenStatus.SUCCESS
except ujson.JSONDecodeError:
logging.info("[LIBRARY][ERROR]: Empty JSON file!")
logger.info("[LIBRARY][ERROR]: Empty JSON file!")
# If the Library is loaded, continue other processes.
if return_code == 1:
if return_code == OpenStatus.SUCCESS:
(self.library_dir / TS_FOLDER_NAME).mkdir(parents=True, exist_ok=True)
self._map_filenames_to_entry_ids()
@@ -759,7 +724,7 @@ class Library:
Used in saving the library to disk.
"""
file_to_save: JsonLibary = {
file_to_save = {
"ts-version": VERSION,
"ext_list": [i for i in self.ext_list if i],
"is_exclude_list": self.is_exclude_list,
@@ -790,7 +755,7 @@ class Library:
def save_library_to_disk(self):
"""Saves the Library to disk at the default TagStudio folder location."""
logging.info(f"[LIBRARY] Saving Library to Disk...")
logger.info(f"[LIBRARY] Saving Library to Disk...")
start_time = time.time()
filename = "ts_library.json"
@@ -808,7 +773,7 @@ class Library:
)
# , indent=4 <-- How to prettyprint dump
end_time = time.time()
logging.info(
logger.info(
f"[LIBRARY] Library saved to disk in {(end_time - start_time):.3f} seconds"
)
@@ -817,7 +782,7 @@ class Library:
Saves a backup file of the Library to disk at the default TagStudio folder location.
Returns the filename used, including the date and time."""
logging.info(f"[LIBRARY] Saving Library Backup to Disk...")
logger.info(f"[LIBRARY] Saving Library Backup to Disk...")
start_time = time.time()
filename = f'ts_library_backup_{datetime.datetime.utcnow().strftime("%F_%T").replace(":", "")}.json'
@@ -835,7 +800,7 @@ class Library:
escape_forward_slashes=False,
)
end_time = time.time()
logging.info(
logger.info(
f"[LIBRARY] Library backup saved to disk in {(end_time - start_time):.3f} seconds"
)
return filename
@@ -908,7 +873,7 @@ class Library:
# print(file)
self.files_not_in_library.append(file)
except PermissionError:
logging.info(
logger.info(
f"The File/Folder {f} cannot be accessed, because it requires higher permission!"
)
end_time = time.time()
@@ -951,7 +916,7 @@ class Library:
# Remove this Entry from the Entries list.
entry = self.get_entry(entry_id)
path = entry.path / entry.filename
# logging.info(f'Removing path: {path}')
# logger.info(f'Removing path: {path}')
del self.filename_to_entry_id_map[path]
@@ -1000,9 +965,9 @@ class Library:
for k, v in registered.items():
if len(v) > 1:
self.dupe_entries.append((v[0], v[1:]))
# logging.info(f"DUPLICATE FOUND: {(v[0], v[1:])}")
# logger.info(f"DUPLICATE FOUND: {(v[0], v[1:])}")
# for id in v:
# logging.info(f"\t{(Path()/self.get_entry(id).path/self.get_entry(id).filename)}")
# logger.info(f"\t{(Path()/self.get_entry(id).path/self.get_entry(id).filename)}")
yield len(self.entries)
@@ -1014,7 +979,7 @@ class Library:
`dupe_entries = tuple(int, list[int])`
"""
logging.info("[LIBRARY] Mirroring Duplicate Entries...")
logger.info("[LIBRARY] Mirroring Duplicate Entries...")
id_to_entry_map: dict = {}
for dupe in self.dupe_entries:
@@ -1026,7 +991,7 @@ class Library:
id_to_entry_map[id] = self.get_entry(id)
self.mirror_entry_fields([dupe[0]] + dupe[1])
logging.info(
logger.info(
"[LIBRARY] Consolidating Entries... (This may take a while for larger libraries)"
)
for i, dupe in enumerate(self.dupe_entries):
@@ -1037,7 +1002,7 @@ class Library:
# takes but in a batch-friendly way here.
# NOTE: Couldn't use get_entry(id) because that relies on the
# entry's index in the list, which is currently being messed up.
logging.info(f"[LIBRARY] Removing Unneeded Entry {id}")
logger.info(f"[LIBRARY] Removing Unneeded Entry {id}")
self.entries.remove(id_to_entry_map[id])
yield i - 1 # The -1 waits for the next step to finish
@@ -1107,12 +1072,12 @@ class Library:
# pb.setLabelText(f'Deleting {i}/{len(self.lib.missing_files)} Unlinked Entries')
try:
id = self.get_entry_id_from_filepath(missing)
logging.info(f"Removing Entry ID {id}:\n\t{missing}")
logger.info(f"Removing Entry ID {id}:\n\t{missing}")
self.remove_entry(id)
# self.driver.purge_item_from_navigation(ItemType.ENTRY, id)
deleted.append(missing)
except KeyError:
logging.info(
logger.info(
f'[LIBRARY][ERROR]: "{id}" was reported as missing, but is not in the file_to_entry_id map.'
)
yield (i, id)
@@ -1288,7 +1253,10 @@ class Library:
path = Path(file)
# print(os.path.split(file))
entry = Entry(
id=self._next_entry_id, filename=path.name, path=path.parent, fields=[]
id=self._next_entry_id,
filename=path.name,
path=path.parent,
fields=[],
)
self._next_entry_id += 1
self.add_entry_to_library(entry)
@@ -1331,7 +1299,7 @@ class Library:
entries=True,
collations=True,
tag_groups=True,
search_mode=SearchMode.AND,
search_mode=0, # AND
) -> list[tuple[ItemType, int]]:
"""
Uses a search query to generate a filtered results list.
@@ -1473,7 +1441,7 @@ class Library:
if not added:
results.append((ItemType.ENTRY, entry.id))
if search_mode == SearchMode.AND: # Include all terms
if search_mode == 0: # AND # Include all terms
# For each verified, extracted Tag term.
failure_to_union_terms = False
for term in all_tag_terms:
@@ -1507,7 +1475,7 @@ class Library:
if all_tag_terms and not failure_to_union_terms:
add_entry(entry)
if search_mode == SearchMode.OR: # Include any terms
if search_mode == 1: # OR # Include any terms
# For each verified, extracted Tag term.
for term in all_tag_terms:
# Add the immediate associated Tags to the set (ex. Name, Alias hits)
@@ -1771,7 +1739,7 @@ class Library:
"""Returns a list of Field Template IDs returned from a string query."""
matches: list[int] = []
for ft in self.default_fields:
for ft in DEFAULT_FIELDS:
if ft["name"].lower().startswith(query.lower()):
matches.append(ft["id"])
@@ -2104,7 +2072,7 @@ class Library:
elif field_type == "datetime":
entry.fields.append({int(field_id): ""})
else:
logging.info(
logger.info(
f"[LIBRARY][ERROR]: Unknown field id attempted to be added to entry: {field_id}"
)
@@ -2181,8 +2149,8 @@ class Library:
Returns a field template object associated with a field ID.
The objects have "id", "name", and "type" fields.
"""
if int(field_id) < len(self.default_fields):
return self.default_fields[int(field_id)]
if int(field_id) < len(DEFAULT_FIELDS):
return DEFAULT_FIELDS[int(field_id)]
else:
return {"id": -1, "name": "Unknown Field", "type": "unknown"}

View File

@@ -1,11 +1,18 @@
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import traceback
from enum import IntEnum
from typing import Any
from enum import Enum
import structlog
from src.core.library.alchemy.enums import TagColor
logger = structlog.get_logger(__name__)
class ColorType(int, Enum):
class ColorType(IntEnum):
PRIMARY = 0
TEXT = 1
BORDER = 2
@@ -13,71 +20,71 @@ class ColorType(int, Enum):
DARK_ACCENT = 4
_TAG_COLORS = {
"": {
TAG_COLORS: dict[TagColor, dict[ColorType, Any]] = {
TagColor.DEFAULT: {
ColorType.PRIMARY: "#1e1e1e",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#333333",
ColorType.LIGHT_ACCENT: "#FFFFFF",
ColorType.DARK_ACCENT: "#222222",
},
"black": {
TagColor.BLACK: {
ColorType.PRIMARY: "#111018",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#18171e",
ColorType.LIGHT_ACCENT: "#b7b6be",
ColorType.DARK_ACCENT: "#03020a",
},
"dark gray": {
TagColor.DARK_GRAY: {
ColorType.PRIMARY: "#24232a",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#2a2930",
ColorType.LIGHT_ACCENT: "#bdbcc4",
ColorType.DARK_ACCENT: "#07060e",
},
"gray": {
TagColor.GRAY: {
ColorType.PRIMARY: "#53525a",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#5b5a62",
ColorType.LIGHT_ACCENT: "#cbcad2",
ColorType.DARK_ACCENT: "#191820",
},
"light gray": {
TagColor.LIGHT_GRAY: {
ColorType.PRIMARY: "#aaa9b0",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#b6b4bc",
ColorType.LIGHT_ACCENT: "#cbcad2",
ColorType.DARK_ACCENT: "#191820",
},
"white": {
TagColor.WHITE: {
ColorType.PRIMARY: "#f2f1f8",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#fefeff",
ColorType.LIGHT_ACCENT: "#ffffff",
ColorType.DARK_ACCENT: "#302f36",
},
"light pink": {
TagColor.LIGHT_PINK: {
ColorType.PRIMARY: "#ff99c4",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#ffaad0",
ColorType.LIGHT_ACCENT: "#ffcbe7",
ColorType.DARK_ACCENT: "#6c2e3b",
},
"pink": {
TagColor.PINK: {
ColorType.PRIMARY: "#ff99c4",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#ffaad0",
ColorType.LIGHT_ACCENT: "#ffcbe7",
ColorType.DARK_ACCENT: "#6c2e3b",
},
"magenta": {
TagColor.MAGENTA: {
ColorType.PRIMARY: "#f6466f",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#f7587f",
ColorType.LIGHT_ACCENT: "#fba4bf",
ColorType.DARK_ACCENT: "#61152f",
},
"red": {
TagColor.RED: {
ColorType.PRIMARY: "#e22c3c",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#b21f2d",
@@ -85,35 +92,35 @@ _TAG_COLORS = {
ColorType.LIGHT_ACCENT: "#f39caa",
ColorType.DARK_ACCENT: "#440d12",
},
"red orange": {
TagColor.RED_ORANGE: {
ColorType.PRIMARY: "#e83726",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#ea4b3b",
ColorType.LIGHT_ACCENT: "#f5a59d",
ColorType.DARK_ACCENT: "#61120b",
},
"salmon": {
TagColor.SALMON: {
ColorType.PRIMARY: "#f65848",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#f76c5f",
ColorType.LIGHT_ACCENT: "#fcadaa",
ColorType.DARK_ACCENT: "#6f1b16",
},
"orange": {
TagColor.ORANGE: {
ColorType.PRIMARY: "#ed6022",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#ef7038",
ColorType.LIGHT_ACCENT: "#f7b79b",
ColorType.DARK_ACCENT: "#551e0a",
},
"yellow orange": {
TagColor.YELLOW_ORANGE: {
ColorType.PRIMARY: "#fa9a2c",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#fba94b",
ColorType.LIGHT_ACCENT: "#fdd7ab",
ColorType.DARK_ACCENT: "#66330d",
},
"yellow": {
TagColor.YELLOW: {
ColorType.PRIMARY: "#ffd63d",
ColorType.TEXT: ColorType.DARK_ACCENT,
# ColorType.BORDER: '#ffe071',
@@ -121,154 +128,154 @@ _TAG_COLORS = {
ColorType.LIGHT_ACCENT: "#fff3c4",
ColorType.DARK_ACCENT: "#754312",
},
"mint": {
TagColor.MINT: {
ColorType.PRIMARY: "#4aed90",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#79f2b1",
ColorType.LIGHT_ACCENT: "#c8fbe9",
ColorType.DARK_ACCENT: "#164f3e",
},
"lime": {
TagColor.LIME: {
ColorType.PRIMARY: "#92e649",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#b2ed72",
ColorType.LIGHT_ACCENT: "#e9f9b7",
ColorType.DARK_ACCENT: "#405516",
},
"light green": {
TagColor.LIGHT_GREEN: {
ColorType.PRIMARY: "#85ec76",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#a3f198",
ColorType.LIGHT_ACCENT: "#e7fbe4",
ColorType.DARK_ACCENT: "#2b5524",
},
"green": {
TagColor.GREEN: {
ColorType.PRIMARY: "#28bb48",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#43c568",
ColorType.LIGHT_ACCENT: "#93e2c8",
ColorType.DARK_ACCENT: "#0d3828",
},
"teal": {
TagColor.TEAL: {
ColorType.PRIMARY: "#1ad9b2",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#4de3c7",
ColorType.LIGHT_ACCENT: "#a0f3e8",
ColorType.DARK_ACCENT: "#08424b",
},
"cyan": {
TagColor.CYAN: {
ColorType.PRIMARY: "#49e4d5",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#76ebdf",
ColorType.LIGHT_ACCENT: "#bff5f0",
ColorType.DARK_ACCENT: "#0f4246",
},
"light blue": {
TagColor.LIGHT_BLUE: {
ColorType.PRIMARY: "#55bbf6",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#70c6f7",
ColorType.LIGHT_ACCENT: "#bbe4fb",
ColorType.DARK_ACCENT: "#122541",
},
"blue": {
TagColor.BLUE: {
ColorType.PRIMARY: "#3b87f0",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#4e95f2",
ColorType.LIGHT_ACCENT: "#aedbfa",
ColorType.DARK_ACCENT: "#122948",
},
"blue violet": {
TagColor.BLUE_VIOLET: {
ColorType.PRIMARY: "#5948f2",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#6258f3",
ColorType.LIGHT_ACCENT: "#9cb8fb",
ColorType.DARK_ACCENT: "#1b1649",
},
"violet": {
TagColor.VIOLET: {
ColorType.PRIMARY: "#874ff5",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#9360f6",
ColorType.LIGHT_ACCENT: "#c9b0fa",
ColorType.DARK_ACCENT: "#3a1860",
},
"purple": {
TagColor.PURPLE: {
ColorType.PRIMARY: "#bb4ff0",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#c364f2",
ColorType.LIGHT_ACCENT: "#dda7f7",
ColorType.DARK_ACCENT: "#531862",
},
"peach": {
TagColor.PEACH: {
ColorType.PRIMARY: "#f1c69c",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#f4d4b4",
ColorType.LIGHT_ACCENT: "#fbeee1",
ColorType.DARK_ACCENT: "#613f2f",
},
"brown": {
TagColor.BROWN: {
ColorType.PRIMARY: "#823216",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#8a3e22",
ColorType.LIGHT_ACCENT: "#cd9d83",
ColorType.DARK_ACCENT: "#3a1804",
},
"lavender": {
TagColor.LAVENDER: {
ColorType.PRIMARY: "#ad8eef",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#b99ef2",
ColorType.LIGHT_ACCENT: "#d5c7fa",
ColorType.DARK_ACCENT: "#492b65",
},
"blonde": {
TagColor.BLONDE: {
ColorType.PRIMARY: "#efc664",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#f3d387",
ColorType.LIGHT_ACCENT: "#faebc6",
ColorType.DARK_ACCENT: "#6d461e",
},
"auburn": {
TagColor.AUBURN: {
ColorType.PRIMARY: "#a13220",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#aa402f",
ColorType.LIGHT_ACCENT: "#d98a7f",
ColorType.DARK_ACCENT: "#3d100a",
},
"light brown": {
TagColor.LIGHT_BROWN: {
ColorType.PRIMARY: "#be5b2d",
ColorType.TEXT: ColorType.DARK_ACCENT,
ColorType.BORDER: "#c4693d",
ColorType.LIGHT_ACCENT: "#e5b38c",
ColorType.DARK_ACCENT: "#4c290e",
},
"dark brown": {
TagColor.DARK_BROWN: {
ColorType.PRIMARY: "#4c2315",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#542a1c",
ColorType.LIGHT_ACCENT: "#b78171",
ColorType.DARK_ACCENT: "#211006",
},
"cool gray": {
TagColor.COOL_GRAY: {
ColorType.PRIMARY: "#515768",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#5b6174",
ColorType.LIGHT_ACCENT: "#9ea1c3",
ColorType.DARK_ACCENT: "#181a37",
},
"warm gray": {
TagColor.WARM_GRAY: {
ColorType.PRIMARY: "#625550",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#6c5e57",
ColorType.LIGHT_ACCENT: "#c0a392",
ColorType.DARK_ACCENT: "#371d18",
},
"olive": {
TagColor.OLIVE: {
ColorType.PRIMARY: "#4c652e",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#586f36",
ColorType.LIGHT_ACCENT: "#b4c17a",
ColorType.DARK_ACCENT: "#23300e",
},
"berry": {
TagColor.BERRY: {
ColorType.PRIMARY: "#9f2aa7",
ColorType.TEXT: ColorType.LIGHT_ACCENT,
ColorType.BORDER: "#aa43b4",
@@ -278,12 +285,14 @@ _TAG_COLORS = {
}
def get_tag_color(type, color):
color = color.lower()
def get_tag_color(color_type: ColorType, color_id: TagColor) -> str:
try:
if type == ColorType.TEXT:
return get_tag_color(_TAG_COLORS[color][type], color)
else:
return _TAG_COLORS[color][type]
if color_type == ColorType.TEXT:
text_account: ColorType = TAG_COLORS[color_id][color_type]
return get_tag_color(text_account, color_id)
return TAG_COLORS[color_id][color_type]
except KeyError:
traceback.print_stack()
logger.error("Color not found", color_id=color_id)
return "#FF00FF"

View File

@@ -5,78 +5,73 @@
"""The core classes and methods of TagStudio."""
import json
import os
from pathlib import Path
from enum import Enum
from src.core.library import Entry, Library
from src.core.constants import TS_FOLDER_NAME, TEXT_FIELDS
from src.core.constants import TS_FOLDER_NAME
from src.core.library.alchemy.fields import _FieldID
from src.core.utils.missing_files import logger
class TagStudioCore:
"""
Instantiate this to establish a TagStudio session.
Holds all TagStudio session data and provides methods to manage it.
"""
def __init__(self):
self.lib: Library = Library()
def get_gdl_sidecar(self, filepath: str | Path, source: str = "") -> dict:
@classmethod
def get_gdl_sidecar(cls, filepath: Path, source: str = "") -> dict:
"""
Attempts to open and dump a Gallery-DL Sidecar sidecar file for
the filepath.\n Returns a formatted object with notable values or an
empty object if none is found.
Attempt to open and dump a Gallery-DL Sidecar file for the filepath.
Return a formatted object with notable values or an empty object if none is found.
"""
json_dump = {}
info = {}
_filepath: Path = Path(filepath)
_filepath = _filepath.parent / (_filepath.stem + ".json")
_filepath = filepath.parent / (filepath.stem + ".json")
# NOTE: This fixes an unknown (recent?) bug in Gallery-DL where Instagram sidecar
# files may be downloaded with indices starting at 1 rather than 0, unlike the posts.
# This may only occur with sidecar files that are downloaded separate from posts.
if source == "instagram":
if not _filepath.is_file():
newstem = _filepath.stem[:-16] + "1" + _filepath.stem[-15:]
_filepath = _filepath.parent / (newstem + ".json")
if source == "instagram" and not _filepath.is_file():
newstem = _filepath.stem[:-16] + "1" + _filepath.stem[-15:]
_filepath = _filepath.parent / (newstem + ".json")
logger.info(
"get_gdl_sidecar", filepath=filepath, source=source, sidecar=_filepath
)
try:
with open(_filepath, "r", encoding="utf8") as f:
with open(_filepath, encoding="utf8") as f:
json_dump = json.load(f)
if not json_dump:
return {}
if json_dump:
if source == "twitter":
info["content"] = json_dump["content"].strip()
info["date_published"] = json_dump["date"]
elif source == "instagram":
info["description"] = json_dump["description"].strip()
info["date_published"] = json_dump["date"]
elif source == "artstation":
info["title"] = json_dump["title"].strip()
info["artist"] = json_dump["user"]["full_name"].strip()
info["description"] = json_dump["description"].strip()
info["tags"] = json_dump["tags"]
# info["tags"] = [x for x in json_dump["mediums"]["name"]]
info["date_published"] = json_dump["date"]
elif source == "newgrounds":
# info["title"] = json_dump["title"]
# info["artist"] = json_dump["artist"]
# info["description"] = json_dump["description"]
info["tags"] = json_dump["tags"]
info["date_published"] = json_dump["date"]
info["artist"] = json_dump["user"].strip()
info["description"] = json_dump["description"].strip()
info["source"] = json_dump["post_url"].strip()
if source == "twitter":
info[_FieldID.DESCRIPTION] = json_dump["content"].strip()
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
elif source == "instagram":
info[_FieldID.DESCRIPTION] = json_dump["description"].strip()
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
elif source == "artstation":
info[_FieldID.TITLE] = json_dump["title"].strip()
info[_FieldID.ARTIST] = json_dump["user"]["full_name"].strip()
info[_FieldID.DESCRIPTION] = json_dump["description"].strip()
info[_FieldID.TAGS] = json_dump["tags"]
# info["tags"] = [x for x in json_dump["mediums"]["name"]]
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
elif source == "newgrounds":
# info["title"] = json_dump["title"]
# info["artist"] = json_dump["artist"]
# info["description"] = json_dump["description"]
info[_FieldID.TAGS] = json_dump["tags"]
info[_FieldID.DATE_PUBLISHED] = json_dump["date"]
info[_FieldID.ARTIST] = json_dump["user"].strip()
info[_FieldID.DESCRIPTION] = json_dump["description"].strip()
info[_FieldID.SOURCE] = json_dump["post_url"].strip()
# else:
# print(
# f'[INFO]: TagStudio does not currently support sidecar files for "{source}"')
# except FileNotFoundError:
except:
# print(
# f'[INFO]: No sidecar file found at "{os.path.normpath(file_path + ".json")}"')
pass
except Exception:
logger.exception("Error handling sidecar file.", path=_filepath)
return info
@@ -103,102 +98,86 @@ class TagStudioCore:
# # # print("Could not resolve URL.")
# # pass
def match_conditions(self, entry_id: int) -> None:
"""Matches defined conditions against a file to add Entry data."""
@classmethod
def match_conditions(cls, lib: Library, entry_id: int) -> bool:
"""Match defined conditions against a file to add Entry data."""
cond_file = self.lib.library_dir / TS_FOLDER_NAME / "conditions.json"
# TODO - what even is this file format?
# TODO: Make this stored somewhere better instead of temporarily in this JSON file.
entry: Entry = self.lib.get_entry(entry_id)
cond_file = lib.library_dir / TS_FOLDER_NAME / "conditions.json"
if not cond_file.is_file():
return False
entry: Entry = lib.get_entry(entry_id)
try:
if cond_file.is_file():
with open(cond_file, "r", encoding="utf8") as f:
json_dump = json.load(f)
for c in json_dump["conditions"]:
match: bool = False
for path_c in c["path_conditions"]:
if str(Path(path_c).resolve()) in str(entry.path):
match = True
break
if match:
if fields := c.get("fields"):
for field in fields:
field_id = self.lib.get_field_attr(field, "id")
content = field[field_id]
with open(cond_file, encoding="utf8") as f:
json_dump = json.load(f)
for c in json_dump["conditions"]:
match: bool = False
for path_c in c["path_conditions"]:
if Path(path_c).is_relative_to(entry.path):
match = True
break
if (
self.lib.get_field_obj(int(field_id))["type"]
== "tag_box"
):
existing_fields: list[int] = (
self.lib.get_field_index_in_entry(
entry, field_id
)
)
if existing_fields:
self.lib.update_entry_field(
entry_id,
existing_fields[0],
content,
"append",
)
else:
self.lib.add_field_to_entry(
entry_id, field_id
)
self.lib.update_entry_field(
entry_id, -1, content, "append"
)
if not match:
return False
if (
self.lib.get_field_obj(int(field_id))["type"]
in TEXT_FIELDS
):
if not self.lib.does_field_content_exist(
entry_id, field_id, content
):
self.lib.add_field_to_entry(
entry_id, field_id
)
self.lib.update_entry_field(
entry_id, -1, content, "replace"
)
except:
print("Error in match_conditions...")
# input()
pass
if not c.get("fields"):
return False
def build_url(self, entry_id: int, source: str):
"""Tries to rebuild a source URL given a specific filename structure."""
fields = c["fields"]
entry_field_types = {
field.type_key: field for field in entry.fields
}
for field in fields:
is_new = field["id"] not in entry_field_types
field_key = field["id"]
if is_new:
lib.add_entry_field_type(
entry.id, field_key, field["value"]
)
else:
lib.update_entry_field(entry.id, field_key, field["value"])
except Exception:
logger.exception("Error matching conditions.", entry=entry)
return False
@classmethod
def build_url(cls, entry: Entry, source: str):
"""Try to rebuild a source URL given a specific filename structure."""
source = source.lower().replace("-", " ").replace("_", " ")
if "twitter" in source:
return self._build_twitter_url(entry_id)
return cls._build_twitter_url(entry)
elif "instagram" in source:
return self._build_instagram_url(entry_id)
return cls._build_instagram_url(entry)
def _build_twitter_url(self, entry_id: int):
@classmethod
def _build_twitter_url(cls, entry: Entry):
"""
Builds an Twitter URL given a specific filename structure.
Build a Twitter URL given a specific filename structure.
Method expects filename to be formatted as 'USERNAME_TWEET-ID_INDEX_YEAR-MM-DD'
"""
try:
entry = self.lib.get_entry(entry_id)
stubs = str(entry.filename).rsplit("_", 3)
# print(stubs)
# source, author = os.path.split(entry.path)
stubs = str(entry.path.name).rsplit("_", 3)
url = f"www.twitter.com/{stubs[0]}/status/{stubs[-3]}/photo/{stubs[-2]}"
return url
except:
except Exception:
logger.exception("Error building Twitter URL.", entry=entry)
return ""
def _build_instagram_url(self, entry_id: int):
@classmethod
def _build_instagram_url(cls, entry: Entry):
"""
Builds an Instagram URL given a specific filename structure.
Build an Instagram URL given a specific filename structure.
Method expects filename to be formatted as 'USERNAME_POST-ID_INDEX_YEAR-MM-DD'
"""
try:
entry = self.lib.get_entry(entry_id)
stubs = str(entry.filename).rsplit("_", 2)
stubs = str(entry.path.name).rsplit("_", 2)
# stubs[0] = stubs[0].replace(f"{author}_", '', 1)
# print(stubs)
# NOTE: Both Instagram usernames AND their ID can have underscores in them,
@@ -207,5 +186,6 @@ class TagStudioCore:
# seems to more or less be the case... for now...
url = f"www.instagram.com/p/{stubs[-3][-11:]}"
return url
except:
except Exception:
logger.exception("Error building Instagram URL.", entry=entry)
return ""

View File

@@ -0,0 +1,83 @@
from dataclasses import dataclass, field
from pathlib import Path
import xml.etree.ElementTree as ET
import structlog
from src.core.library import Library, Entry
from src.core.library.alchemy.enums import FilterState
logger = structlog.get_logger()
@dataclass
class DupeRegistry:
"""State handler for DupeGuru results."""
library: Library
groups: list[list[Entry]] = field(default_factory=list)
@property
def groups_count(self) -> int:
return len(self.groups)
def refresh_dupe_files(self, results_filepath: str | Path):
"""
Refresh the list of duplicate files.
A duplicate file is defined as an identical or near-identical file as determined
by a DupeGuru results file.
"""
library_dir = self.library.library_dir
if not isinstance(results_filepath, Path):
results_filepath = Path(results_filepath)
if not results_filepath.is_file():
raise ValueError("invalid file path")
self.groups.clear()
tree = ET.parse(results_filepath)
root = tree.getroot()
for group in root:
# print(f'-------------------- Match Group {i}---------------------')
files: list[Entry] = []
for element in group:
if element.tag == "file":
file_path = Path(element.attrib.get("path"))
try:
path_relative = file_path.relative_to(library_dir)
except ValueError:
# The file is not in the library directory
continue
_, entries = self.library.search_library(
FilterState(path=path_relative),
)
if not entries:
# file not in library
continue
files.append(entries[0])
if not len(files) > 1:
# only one file in the group, nothing to do
continue
self.groups.append(files)
def merge_dupe_entries(self):
"""
Merge the duplicate Entry items.
A duplicate Entry is defined as an Entry pointing to a file that one or more other Entries are also pointing to
"""
logger.info(
"Consolidating Entries... (This may take a while for larger libraries)",
groups=len(self.groups),
)
for i, entries in enumerate(self.groups):
remove_ids = [x.id for x in entries[1:]]
logger.info("Removing entries group", ids=remove_ids)
self.library.remove_entries(remove_ids)
yield i - 1 # The -1 waits for the next step to finish

View File

@@ -1,11 +0,0 @@
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
def clean_folder_name(folder_name: str) -> str:
cleaned_name = folder_name
invalid_chars = '<>:"/\\|?*.'
for char in invalid_chars:
cleaned_name = cleaned_name.replace(char, "_")
return cleaned_name

View File

@@ -0,0 +1,71 @@
from collections.abc import Iterator
from dataclasses import field, dataclass
from pathlib import Path
import structlog
from src.core.library import Library, Entry
IGNORE_ITEMS = [
"$recycle.bin",
]
logger = structlog.get_logger()
@dataclass
class MissingRegistry:
"""State tracker for unlinked and moved files."""
library: Library
files_fixed_count: int = 0
missing_files: list[Entry] = field(default_factory=list)
@property
def missing_files_count(self) -> int:
return len(self.missing_files)
def refresh_missing_files(self) -> Iterator[int]:
"""Track the number of Entries that point to an invalid file path."""
logger.info("refresh_missing_files running")
self.missing_files = []
for i, entry in enumerate(self.library.get_entries()):
full_path = self.library.library_dir / entry.path
if not full_path.exists() or not full_path.is_file():
self.missing_files.append(entry)
yield i
def match_missing_file(self, match_item: Entry) -> list[Path]:
"""
Try to find missing entry files within the library directory.
Works if files were just moved to different subfolders and don't have duplicate names.
"""
matches = []
for item in self.library.library_dir.glob(f"**/{match_item.path.name}"):
if item.name == match_item.path.name: # TODO - implement IGNORE_ITEMS
new_path = Path(item).relative_to(self.library.library_dir)
matches.append(new_path)
return matches
def fix_missing_files(self) -> Iterator[int]:
"""Attempt to fix missing files by finding a match in the library directory."""
self.files_fixed_count = 0
for i, entry in enumerate(self.missing_files, start=1):
item_matches = self.match_missing_file(entry)
if len(item_matches) == 1:
logger.info("fix_missing_files", entry=entry, item_matches=item_matches)
self.library.update_entry_path(entry.id, item_matches[0])
self.files_fixed_count += 1
# remove fixed file
self.missing_files.remove(entry)
yield i
def execute_deletion(self) -> Iterator[int]:
for i, missing in enumerate(self.missing_files, start=1):
# TODO - optimize this by removing multiple entries at once
self.library.remove_entries([missing.id])
yield i
self.missing_files = []

View File

@@ -0,0 +1,73 @@
import time
from collections.abc import Iterator
from dataclasses import dataclass, field
from pathlib import Path
from src.core.constants import TS_FOLDER_NAME
from src.core.library import Library, Entry
@dataclass
class RefreshDirTracker:
library: Library
dir_file_count: int = 0
files_not_in_library: list[Path] = field(default_factory=list)
@property
def files_count(self) -> int:
return len(self.files_not_in_library)
def save_new_files(self) -> Iterator[int]:
"""Save the list of files that are not in the library."""
if not self.files_not_in_library:
yield 0
for idx, entry_path in enumerate(self.files_not_in_library):
self.library.add_entries(
[
Entry(
path=entry_path,
folder=self.library.folder,
fields=self.library.default_fields,
)
]
)
yield idx
self.files_not_in_library = []
def refresh_dir(self) -> Iterator[int]:
"""Scan a directory for files, and add those relative filenames to internal variables."""
if self.library.folder is None:
raise ValueError("No folder set.")
start_time = time.time()
self.files_not_in_library = []
self.dir_file_count = 0
lib_path = self.library.folder.path
for path in lib_path.glob("**/*"):
str_path = str(path)
if (
path.is_dir()
or "$RECYCLE.BIN" in str_path
or TS_FOLDER_NAME in str_path
or "tagstudio_thumbs" in str_path
):
continue
suffix = path.suffix.lower().lstrip(".")
if suffix in self.library.ignored_extensions:
continue
self.dir_file_count += 1
relative_path = path.relative_to(lib_path)
# TODO - load these in batch somehow
if not self.library.has_path_entry(relative_path):
self.files_not_in_library.append(relative_path)
end_time = time.time()
# Yield output every 1/30 of a second
if (end_time - start_time) > 0.034:
yield self.dir_file_count

View File

@@ -72,14 +72,12 @@ class FlowLayout(QLayout):
return height
def setGeometry(self, rect):
super(FlowLayout, self).setGeometry(rect)
super().setGeometry(rect)
self._do_layout(rect, False)
def setGridEfficiency(self, bool):
"""
Enables or Disables efficiencies when all objects are equally sized.
"""
self.grid_efficiency = bool
def setGridEfficiency(self, value: bool):
"""Enable or Disable efficiencies when all objects are equally sized."""
self.grid_efficiency = value
def sizeHint(self):
return self.minimumSize()
@@ -101,27 +99,29 @@ class FlowLayout(QLayout):
)
return size
def _do_layout(self, rect, test_only):
def _do_layout(self, rect: QRect, test_only: bool) -> float:
x = rect.x()
y = rect.y()
line_height = 0
spacing = self.spacing()
item = None
style = None
layout_spacing_x = None
layout_spacing_y = None
if self.grid_efficiency:
if self._item_list:
item = self._item_list[0]
style = item.widget().style()
layout_spacing_x = style.layoutSpacing(
QSizePolicy.PushButton, QSizePolicy.PushButton, Qt.Horizontal
)
layout_spacing_y = style.layoutSpacing(
QSizePolicy.PushButton, QSizePolicy.PushButton, Qt.Vertical
)
for i, item in enumerate(self._item_list):
if self.grid_efficiency and self._item_list:
item = self._item_list[0]
style = item.widget().style()
layout_spacing_x = style.layoutSpacing(
QSizePolicy.ControlType.PushButton,
QSizePolicy.ControlType.PushButton,
Qt.Orientation.Horizontal,
)
layout_spacing_y = style.layoutSpacing(
QSizePolicy.ControlType.PushButton,
QSizePolicy.ControlType.PushButton,
Qt.Orientation.Vertical,
)
for item in self._item_list:
# print(issubclass(type(item.widget()), FlowWidget))
# print(item.widget().ignore_size)
skip_count = 0
@@ -139,10 +139,14 @@ class FlowLayout(QLayout):
if not self.grid_efficiency:
style = item.widget().style()
layout_spacing_x = style.layoutSpacing(
QSizePolicy.PushButton, QSizePolicy.PushButton, Qt.Horizontal
QSizePolicy.ControlType.PushButton,
QSizePolicy.ControlType.PushButton,
Qt.Orientation.Horizontal,
)
layout_spacing_y = style.layoutSpacing(
QSizePolicy.PushButton, QSizePolicy.PushButton, Qt.Vertical
QSizePolicy.ControlType.PushButton,
QSizePolicy.ControlType.PushButton,
Qt.Orientation.Vertical,
)
space_x = spacing + layout_spacing_x
space_y = spacing + layout_spacing_y

View File

@@ -2,22 +2,18 @@
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
import os
import subprocess
import shutil
import sys
import traceback
from pathlib import Path
import structlog
from PySide6.QtWidgets import QLabel
from PySide6.QtCore import Qt
ERROR = f"[ERROR]"
WARNING = f"[WARNING]"
INFO = f"[INFO]"
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = structlog.get_logger(__name__)
def open_file(path: str | Path, file_manager: bool = False):
@@ -28,14 +24,15 @@ def open_file(path: str | Path, file_manager: bool = False):
file_manager (bool, optional): Whether to open the file in the file manager (e.g. Finder on macOS).
Defaults to False.
"""
_path = str(path)
logging.info(f"Opening file: {_path}")
if not os.path.exists(_path):
logging.error(f"File not found: {_path}")
path = Path(path)
logger.info("Opening file", path=path)
if not path.exists():
logger.error("File not found", path=path)
return
try:
if sys.platform == "win32":
normpath = os.path.normpath(_path)
normpath = Path(path).resolve().as_posix()
if file_manager:
command_name = "explorer"
command_args = '/select,"' + normpath + '"'
@@ -61,7 +58,7 @@ def open_file(path: str | Path, file_manager: bool = False):
else:
if sys.platform == "darwin":
command_name = "open"
command_args = [_path]
command_args = [str(path)]
if file_manager:
# will reveal in Finder
command_args.append("-R")
@@ -75,18 +72,20 @@ def open_file(path: str | Path, file_manager: bool = False):
"--type=method_call",
"/org/freedesktop/FileManager1",
"org.freedesktop.FileManager1.ShowItems",
f"array:string:file://{_path}",
f"array:string:file://{str(path)}",
"string:",
]
else:
command_name = "xdg-open"
command_args = [_path]
command_args = [str(path)]
command = shutil.which(command_name)
if command is not None:
subprocess.Popen([command] + command_args, close_fds=True)
else:
logging.info(f"Could not find {command_name} on system PATH")
except:
logger.info(
"Could not find command on system PATH", command=command_name
)
except Exception:
traceback.print_exc()
@@ -144,9 +143,9 @@ class FileOpenerLabel(QLabel):
"""
super().mousePressEvent(event)
if event.button() == Qt.LeftButton:
if event.button() == Qt.MouseButton.LeftButton:
opener = FileOpenerHelper(self.filepath)
opener.open_explorer()
elif event.button() == Qt.RightButton:
elif event.button() == Qt.MouseButton.RightButton:
# Show context menu
pass

View File

@@ -1,14 +1,13 @@
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
from collections.abc import Callable
from PySide6.QtCore import Signal, QObject
from typing import Callable
class FunctionIterator(QObject):
"""Iterates over a yielding function and emits progress as the 'value' signal.\n\nThread-Safe Guarantee™"""
"""Iterate over a yielding function and emit progress as the 'value' signal."""
value = Signal(object)

View File

@@ -36,7 +36,7 @@ class Ui_MainWindow(QMainWindow):
def __init__(self, driver: "QtDriver", parent=None) -> None:
super().__init__(parent)
self.driver: "QtDriver" = driver
self.driver = driver
self.setupUi(self)
# NOTE: These are old attempts to allow for a translucent/acrylic
@@ -235,4 +235,4 @@ class Ui_MainWindow(QMainWindow):
else:
self.landing_widget.setHidden(True)
self.landing_widget.set_status_label("")
self.scrollArea.setHidden(False)
self.scrollArea.setHidden(False)

View File

@@ -10,23 +10,24 @@ from PySide6.QtWidgets import (
QHBoxLayout,
QLabel,
QPushButton,
QComboBox,
QListWidget,
QListWidgetItem,
)
from src.core.library import Library
class AddFieldModal(QWidget):
done = Signal(int)
done = Signal(list)
def __init__(self, library: "Library"):
def __init__(self, library: Library):
# [Done]
# - OR -
# [Cancel] [Save]
super().__init__()
self.is_connected = False
self.lib = library
self.setWindowTitle(f"Add Field")
self.setWindowTitle("Add Field")
self.setWindowModality(Qt.WindowModality.ApplicationModal)
self.setMinimumSize(400, 300)
self.root_layout = QVBoxLayout(self)
@@ -43,17 +44,7 @@ class AddFieldModal(QWidget):
self.title_widget.setText("Add Field")
self.title_widget.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.combo_box = QComboBox()
self.combo_box.setEditable(False)
# self.combo_box.setMaxVisibleItems(5)
self.combo_box.setStyleSheet("combobox-popup:0;")
self.combo_box.view().setVerticalScrollBarPolicy(
Qt.ScrollBarPolicy.ScrollBarAsNeeded
)
for df in self.lib.default_fields:
self.combo_box.addItem(
f'{df["name"]} ({df["type"].replace("_", " ").title()})'
)
self.list_widget = QListWidget()
self.button_container = QWidget()
self.button_layout = QHBoxLayout(self.button_container)
@@ -75,17 +66,25 @@ class AddFieldModal(QWidget):
self.save_button.setDefault(True)
self.save_button.clicked.connect(self.hide)
self.save_button.clicked.connect(
lambda: self.done.emit(self.combo_box.currentIndex())
lambda: (
# get userData for each selected item
self.done.emit(self.list_widget.selectedItems())
)
)
# self.save_button.clicked.connect(lambda: save_callback(widget.get_content()))
self.button_layout.addWidget(self.save_button)
# self.returnPressed.connect(lambda: self.done.emit(self.combo_box.currentIndex()))
# self.done.connect(lambda x: callback(x))
self.root_layout.addWidget(self.title_widget)
self.root_layout.addWidget(self.combo_box)
self.root_layout.addWidget(self.list_widget)
# self.root_layout.setStretch(1,2)
self.root_layout.addStretch(1)
self.root_layout.addWidget(self.button_container)
def show(self):
self.list_widget.clear()
for df in self.lib.field_types.values():
item = QListWidgetItem(f"{df.name} ({df.type.value})")
item.setData(Qt.ItemDataRole.UserRole, df.key)
self.list_widget.addItem(item)
super().show()

View File

@@ -3,8 +3,7 @@
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
import structlog
from PySide6.QtCore import Signal, Qt
from PySide6.QtWidgets import (
QWidget,
@@ -18,30 +17,26 @@ from PySide6.QtWidgets import (
QComboBox,
)
from src.core.library import Library, Tag
from src.core.library import Tag, Library
from src.core.library.alchemy.enums import TagColor
from src.core.palette import ColorType, get_tag_color
from src.core.constants import TAG_COLORS
from src.qt.widgets.panel import PanelWidget, PanelModal
from src.qt.widgets.tag import TagWidget
from src.qt.modals.tag_search import TagSearchPanel
ERROR = f"[ERROR]"
WARNING = f"[WARNING]"
INFO = f"[INFO]"
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = structlog.get_logger(__name__)
class BuildTagPanel(PanelWidget):
on_edit = Signal(Tag)
def __init__(self, library, tag_id: int = -1):
def __init__(self, library: Library, tag: Tag | None = None):
super().__init__()
self.lib: Library = library
self.lib = library
# self.callback = callback
# self.tag_id = tag_id
self.tag = None
self.setMinimumSize(300, 400)
self.root_layout = QVBoxLayout(self)
self.root_layout.setContentsMargins(6, 0, 6, 0)
@@ -95,6 +90,7 @@ class BuildTagPanel(PanelWidget):
self.subtags_layout.setContentsMargins(0, 0, 0, 0)
self.subtags_layout.setSpacing(0)
self.subtags_layout.setAlignment(Qt.AlignmentFlag.AlignLeft)
self.subtags_title = QLabel()
self.subtags_title.setText("Parent Tags")
self.subtags_layout.addWidget(self.subtags_title)
@@ -140,15 +136,18 @@ class BuildTagPanel(PanelWidget):
self.color_field.setEditable(False)
self.color_field.setMaxVisibleItems(10)
self.color_field.setStyleSheet("combobox-popup:0;")
for color in TAG_COLORS:
self.color_field.addItem(color.title())
for color in TagColor:
self.color_field.addItem(color.name, userData=color.value)
# self.color_field.setProperty("appearance", "flat")
self.color_field.currentTextChanged.connect(
lambda c: self.color_field.setStyleSheet(f"""combobox-popup:0;
font-weight:600;
color:{get_tag_color(ColorType.TEXT, c.lower())};
background-color:{get_tag_color(ColorType.PRIMARY, c.lower())};
""")
self.color_field.currentIndexChanged.connect(
lambda c: (
self.color_field.setStyleSheet(
"combobox-popup:0;"
"font-weight:600;"
f"color:{get_tag_color(ColorType.TEXT, self.color_field.currentData())};"
f"background-color:{get_tag_color(ColorType.PRIMARY, self.color_field.currentData())};"
)
)
)
self.color_layout.addWidget(self.color_field)
@@ -160,86 +159,59 @@ class BuildTagPanel(PanelWidget):
self.root_layout.addWidget(self.color_widget)
# self.parent().done.connect(self.update_tag)
if tag_id >= 0:
self.tag = self.lib.get_tag(tag_id)
else:
self.tag = Tag(-1, "New Tag", "", [], [], "")
self.set_tag(self.tag)
# TODO - fill subtags
self.subtags: set[int] = set()
self.set_tag(tag or Tag(name="New Tag"))
def add_subtag_callback(self, tag_id: int):
logging.info(f"adding {tag_id}")
# tag = self.lib.get_tag(self.tag_id)
# TODO: Create a single way to update tags and refresh library data
# new = self.build_tag()
self.tag.add_subtag(tag_id)
# self.tag = new
# self.lib.update_tag(new)
logger.info("add_subtag_callback", tag_id=tag_id)
self.subtags.add(tag_id)
self.set_subtags()
# self.on_edit.emit(self.build_tag())
def remove_subtag_callback(self, tag_id: int):
logging.info(f"removing {tag_id}")
# tag = self.lib.get_tag(self.tag_id)
# TODO: Create a single way to update tags and refresh library data
# new = self.build_tag()
self.tag.remove_subtag(tag_id)
# self.tag = new
# self.lib.update_tag(new)
logger.info("removing subtag", tag_id=tag_id)
self.subtags.remove(tag_id)
self.set_subtags()
# self.on_edit.emit(self.build_tag())
def set_subtags(self):
while self.scroll_layout.itemAt(0):
self.scroll_layout.takeAt(0).widget().deleteLater()
logging.info(f"Setting {self.tag.subtag_ids}")
c = QWidget()
l = QVBoxLayout(c)
l.setContentsMargins(0, 0, 0, 0)
l.setSpacing(3)
for tag_id in self.tag.subtag_ids:
tw = TagWidget(self.lib, self.lib.get_tag(tag_id), False, True)
tw.on_remove.connect(
lambda checked=False, t=tag_id: self.remove_subtag_callback(t)
)
l.addWidget(tw)
layout = QVBoxLayout(c)
layout.setContentsMargins(0, 0, 0, 0)
layout.setSpacing(3)
for tag_id in self.subtags:
tag = self.lib.get_tag(tag_id)
tw = TagWidget(tag, False, True)
tw.on_remove.connect(lambda t=tag_id: self.remove_subtag_callback(t))
layout.addWidget(tw)
self.scroll_layout.addWidget(c)
def set_tag(self, tag: Tag):
# tag = self.lib.get_tag(tag_id)
logger.info("setting tag", tag=tag)
self.name_field.setText(tag.name)
self.shorthand_field.setText(tag.shorthand)
self.aliases_field.setText("\n".join(tag.aliases))
self.shorthand_field.setText(tag.shorthand or "")
# TODO: Implement aliases
# self.aliases_field.setText("\n".join(tag.aliases))
self.set_subtags()
self.color_field.setCurrentIndex(TAG_COLORS.index(tag.color.lower()))
# self.tag_id = tag.id
# select item in self.color_field where the userData value matched tag.color
for i in range(self.color_field.count()):
if self.color_field.itemData(i) == tag.color:
self.color_field.setCurrentIndex(i)
break
self.tag = tag
def build_tag(self) -> Tag:
# tag: Tag = self.tag
# if self.tag_id >= 0:
# tag = self.lib.get_tag(self.tag_id)
# else:
# tag = Tag(-1, '', '', [], [], '')
new_tag: Tag = Tag(
id=self.tag.id,
name=self.name_field.text(),
shorthand=self.shorthand_field.text(),
aliases=self.aliases_field.toPlainText().split("\n"),
subtags_ids=self.tag.subtag_ids,
color=self.color_field.currentText().lower(),
)
logging.info(f"built {new_tag}")
return new_tag
color = self.color_field.currentData() or TagColor.DEFAULT
# NOTE: The callback and signal do the same thing, I'm currently
# transitioning from using callbacks to the Qt method of using signals.
# self.tag_updated.emit(new_tag)
# self.callback(new_tag)
tag = self.tag
# def on_return(self, callback, text:str):
# if text and self.first_tag_id >= 0:
# callback(self.first_tag_id)
# self.search_field.setText('')
# self.update_tags('')
# else:
# self.search_field.setFocus()
# self.parentWidget().hide()
tag.name = self.name_field.text()
tag.shorthand = self.shorthand_field.text()
tag.color = color
logger.info("built tag", tag=tag)
return tag

View File

@@ -15,7 +15,7 @@ from PySide6.QtWidgets import (
QListView,
)
from src.core.library import ItemType, Library
from src.core.utils.missing_files import MissingRegistry
from src.qt.helpers.custom_runnable import CustomRunnable
from src.qt.helpers.function_iterator import FunctionIterator
from src.qt.widgets.progress import ProgressWidget
@@ -28,10 +28,10 @@ if typing.TYPE_CHECKING:
class DeleteUnlinkedEntriesModal(QWidget):
done = Signal()
def __init__(self, library: "Library", driver: "QtDriver"):
def __init__(self, driver: "QtDriver", tracker: MissingRegistry):
super().__init__()
self.lib = library
self.driver = driver
self.tracker = tracker
self.setWindowTitle("Delete Unlinked Entries")
self.setWindowModality(Qt.WindowModality.ApplicationModal)
self.setMinimumSize(500, 400)
@@ -42,7 +42,7 @@ class DeleteUnlinkedEntriesModal(QWidget):
self.desc_widget.setObjectName("descriptionLabel")
self.desc_widget.setWordWrap(True)
self.desc_widget.setText(f"""
Are you sure you want to delete the following {len(self.lib.missing_files)} entries?
Are you sure you want to delete the following {self.tracker.missing_files_count} entries?
""")
self.desc_widget.setAlignment(Qt.AlignmentFlag.AlignCenter)
@@ -73,35 +73,38 @@ class DeleteUnlinkedEntriesModal(QWidget):
def refresh_list(self):
self.desc_widget.setText(f"""
Are you sure you want to delete the following {len(self.lib.missing_files)} entries?
Are you sure you want to delete the following {self.tracker.missing_files_count} entries?
""")
self.model.clear()
for i in self.lib.missing_files:
self.model.appendRow(QStandardItem(str(i)))
for i in self.tracker.missing_files:
self.model.appendRow(QStandardItem(str(i.path)))
def delete_entries(self):
iterator = FunctionIterator(self.lib.remove_missing_files)
pw = ProgressWidget(
window_title="Deleting Entries",
label_text="",
cancel_button_text=None,
minimum=0,
maximum=len(self.lib.missing_files),
maximum=self.tracker.missing_files_count,
)
pw.show()
iterator.value.connect(lambda x: pw.update_progress(x[0] + 1))
iterator = FunctionIterator(self.tracker.execute_deletion)
files_count = self.tracker.missing_files_count
iterator.value.connect(
lambda x: pw.update_label(
f"Deleting {x[0]+1}/{len(self.lib.missing_files)} Unlinked Entries"
lambda idx: (
pw.update_progress(idx),
pw.update_label(f"Deleting {idx}/{files_count} Unlinked Entries"),
)
)
iterator.value.connect(
lambda x: self.driver.purge_item_from_navigation(ItemType.ENTRY, x[1])
)
r = CustomRunnable(lambda: iterator.run())
r = CustomRunnable(iterator.run)
QThreadPool.globalInstance().start(r)
r.done.connect(lambda: (pw.hide(), pw.deleteLater(), self.done.emit()))
r.done.connect(
lambda: (
pw.hide(),
pw.deleteLater(),
self.done.emit(),
)
)

View File

@@ -19,13 +19,17 @@ from PySide6.QtWidgets import (
from src.core.library import Library
from src.qt.widgets.panel import PanelWidget
from src.core.constants import LibraryPrefs
class FileExtensionItemDelegate(QStyledItemDelegate):
def setModelData(self, editor, model, index):
if isinstance(editor, QLineEdit):
if editor.text() and not editor.text().startswith("."):
editor.setText(f".{editor.text()}")
if (
isinstance(editor, QLineEdit)
and editor.text()
and not editor.text().startswith(".")
):
editor.setText(f".{editor.text()}")
super().setModelData(editor, model, index)
@@ -43,7 +47,7 @@ class FileExtensionModal(PanelWidget):
self.root_layout.setContentsMargins(6, 6, 6, 6)
# Create Table Widget --------------------------------------------------
self.table = QTableWidget(len(self.lib.ext_list), 1)
self.table = QTableWidget(len(self.lib.prefs(LibraryPrefs.EXTENSION_LIST)), 1)
self.table.horizontalHeader().setVisible(False)
self.table.verticalHeader().setVisible(False)
self.table.horizontalHeader().setStretchLastSection(True)
@@ -65,9 +69,12 @@ class FileExtensionModal(PanelWidget):
self.mode_label.setText("List Mode:")
self.mode_combobox = QComboBox()
self.mode_combobox.setEditable(False)
self.mode_combobox.addItem("Exclude")
self.mode_combobox.addItem("Include")
self.mode_combobox.setCurrentIndex(0 if self.lib.is_exclude_list else 1)
self.mode_combobox.addItem("Exclude")
is_exclude_list = int(bool(self.lib.prefs(LibraryPrefs.IS_EXCLUDE_LIST)))
self.mode_combobox.setCurrentIndex(is_exclude_list)
self.mode_combobox.currentIndexChanged.connect(
lambda i: self.update_list_mode(i)
)
@@ -91,23 +98,23 @@ class FileExtensionModal(PanelWidget):
Args:
mode (int): The list mode, given by the index of the mode inside
the mode combobox. 0 for "Exclude", 1 for "Include".
the mode combobox. 1 for "Exclude", 0 for "Include".
"""
if mode == 0:
self.lib.is_exclude_list = True
elif mode == 1:
self.lib.is_exclude_list = False
self.lib.set_prefs(LibraryPrefs.IS_EXCLUDE_LIST, bool(mode))
def refresh_list(self):
for i, ext in enumerate(self.lib.ext_list):
for i, ext in enumerate(self.lib.prefs(LibraryPrefs.EXTENSION_LIST)):
self.table.setItem(i, 0, QTableWidgetItem(ext))
def add_item(self):
self.table.insertRow(self.table.rowCount())
def save(self):
self.lib.ext_list.clear()
extensions = []
for i in range(self.table.rowCount()):
ext = self.table.item(i, 0)
if ext and ext.text():
self.lib.ext_list.append(ext.text().lower())
if ext and ext.text().strip():
extensions.append(ext.text().strip().lower())
# save preference
self.lib.set_prefs(LibraryPrefs.EXTENSION_LIST, extensions)

View File

@@ -3,7 +3,6 @@
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import os
import typing
from PySide6.QtCore import Qt
@@ -17,6 +16,7 @@ from PySide6.QtWidgets import (
)
from src.core.library import Library
from src.core.utils.dupe_files import DupeRegistry
from src.qt.modals.mirror_entities import MirrorEntriesModal
# Only import for type checking/autocompletion, will not be imported at runtime.
@@ -32,12 +32,14 @@ class FixDupeFilesModal(QWidget):
self.driver = driver
self.count = -1
self.filename = ""
self.setWindowTitle(f"Fix Duplicate Files")
self.setWindowTitle("Fix Duplicate Files")
self.setWindowModality(Qt.WindowModality.ApplicationModal)
self.setMinimumSize(400, 300)
self.root_layout = QVBoxLayout(self)
self.root_layout.setContentsMargins(6, 6, 6, 6)
self.tracker = DupeRegistry(library=self.lib)
self.desc_widget = QLabel()
self.desc_widget.setObjectName("descriptionLabel")
self.desc_widget.setWordWrap(True)
@@ -80,13 +82,14 @@ class FixDupeFilesModal(QWidget):
self.open_button = QPushButton()
self.open_button.setText("&Load DupeGuru File")
self.open_button.clicked.connect(lambda: self.select_file())
self.open_button.clicked.connect(self.select_file)
self.mirror_modal = MirrorEntriesModal(self.driver, self.tracker)
self.mirror_modal.done.connect(self.refresh_dupes)
self.mirror_button = QPushButton()
self.mirror_modal = MirrorEntriesModal(self.lib, self.driver)
self.mirror_modal.done.connect(lambda: self.refresh_dupes())
self.mirror_button.setText("&Mirror Entries")
self.mirror_button.clicked.connect(lambda: self.mirror_modal.show())
self.mirror_button.clicked.connect(self.mirror_modal.show)
self.mirror_desc = QLabel()
self.mirror_desc.setWordWrap(True)
self.mirror_desc.setText(
@@ -134,7 +137,7 @@ class FixDupeFilesModal(QWidget):
self.root_layout.addStretch(1)
self.root_layout.addWidget(self.button_container)
self.set_dupe_count(self.count)
self.set_dupe_count(-1)
def select_file(self):
qfd = QFileDialog(self, "Open DupeGuru Results File", str(self.lib.library_dir))
@@ -155,15 +158,14 @@ class FixDupeFilesModal(QWidget):
self.mirror_modal.refresh_list()
def refresh_dupes(self):
self.lib.refresh_dupe_files(self.filename)
self.set_dupe_count(len(self.lib.dupe_files))
self.tracker.refresh_dupe_files(self.filename)
self.set_dupe_count(self.tracker.groups_count)
def set_dupe_count(self, count: int):
self.count = count
if self.count < 0:
if count < 0:
self.mirror_button.setDisabled(True)
self.dupe_count.setText(f"Duplicate File Matches: N/A")
elif self.count == 0:
self.dupe_count.setText("Duplicate File Matches: N/A")
elif count == 0:
self.mirror_button.setDisabled(True)
self.dupe_count.setText(f"Duplicate File Matches: {count}")
else:

View File

@@ -3,13 +3,13 @@
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
import typing
from PySide6.QtCore import Qt, QThreadPool
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton
from src.core.library import Library
from src.core.utils.missing_files import MissingRegistry
from src.qt.helpers.function_iterator import FunctionIterator
from src.qt.helpers.custom_runnable import CustomRunnable
from src.qt.modals.delete_unlinked import DeleteUnlinkedEntriesModal
@@ -22,18 +22,14 @@ if typing.TYPE_CHECKING:
from src.qt.ts_qt import QtDriver
ERROR = "[ERROR]"
WARNING = "[WARNING]"
INFO = "[INFO]"
logging.basicConfig(format="%(message)s", level=logging.INFO)
class FixUnlinkedEntriesModal(QWidget):
def __init__(self, library: "Library", driver: "QtDriver"):
super().__init__()
self.lib = library
self.driver = driver
self.tracker = MissingRegistry(library=self.lib)
self.missing_count = -1
self.dupe_count = -1
self.setWindowTitle("Fix Unlinked Entries")
@@ -50,14 +46,6 @@ class FixUnlinkedEntriesModal(QWidget):
"""Each library entry is linked to a file in one of your directories. If a file linked to an entry is moved or deleted outside of TagStudio, it is then considered unlinked. Unlinked entries may be automatically relinked via searching your directories, manually relinked by the user, or deleted if desired."""
)
self.dupe_desc_widget = QLabel()
self.dupe_desc_widget.setObjectName("dupeDescriptionLabel")
self.dupe_desc_widget.setWordWrap(True)
self.dupe_desc_widget.setStyleSheet("text-align:left;")
self.dupe_desc_widget.setText(
"""Duplicate entries are defined as multiple entries which point to the same file on disk. Merging these will combine the tags and metadata from all duplicates into a single consolidated entry. These are not to be confused with "duplicate files", which are duplicates of your files themselves outside of TagStudio."""
)
self.missing_count_label = QLabel()
self.missing_count_label.setObjectName("missingCountLabel")
self.missing_count_label.setStyleSheet("font-weight:bold;" "font-size:14px;")
@@ -70,42 +58,37 @@ class FixUnlinkedEntriesModal(QWidget):
self.refresh_unlinked_button = QPushButton()
self.refresh_unlinked_button.setText("&Refresh All")
self.refresh_unlinked_button.clicked.connect(
lambda: self.refresh_missing_files()
)
self.refresh_unlinked_button.clicked.connect(self.refresh_missing_files)
self.merge_class = MergeDuplicateEntries(self.lib, self.driver)
self.relink_class = RelinkUnlinkedEntries(self.lib, self.driver)
self.relink_class = RelinkUnlinkedEntries(self.tracker)
self.search_button = QPushButton()
self.search_button.setText("&Search && Relink")
self.relink_class.done.connect(
lambda: self.refresh_and_repair_dupe_entries(self.merge_class)
# refresh the grid
lambda: (
self.driver.filter_items(),
self.refresh_missing_files(),
)
)
self.search_button.clicked.connect(lambda: self.relink_class.repair_entries())
self.refresh_dupe_button = QPushButton()
self.refresh_dupe_button.setText("Refresh Duplicate Entries")
self.refresh_dupe_button.clicked.connect(lambda: self.refresh_dupe_entries())
self.merge_dupe_button = QPushButton()
self.merge_dupe_button.setText("&Merge Duplicate Entries")
self.merge_class.done.connect(lambda: self.set_dupe_count(-1))
self.merge_class.done.connect(lambda: self.set_missing_count(-1))
self.merge_class.done.connect(lambda: self.driver.filter_items())
self.merge_dupe_button.clicked.connect(lambda: self.merge_class.merge_entries())
self.search_button.clicked.connect(self.relink_class.repair_entries)
self.manual_button = QPushButton()
self.manual_button.setText("&Manual Relink")
self.manual_button.setHidden(True)
self.delete_button = QPushButton()
self.delete_modal = DeleteUnlinkedEntriesModal(self.lib, self.driver)
self.delete_modal = DeleteUnlinkedEntriesModal(self.driver, self.tracker)
self.delete_modal.done.connect(
lambda: self.set_missing_count(len(self.lib.missing_files))
lambda: (
self.set_missing_count(self.tracker.missing_files_count),
# refresh the grid
self.driver.filter_items(),
)
)
self.delete_modal.done.connect(lambda: self.driver.update_thumbs())
self.delete_button.setText("De&lete Unlinked Entries")
self.delete_button.clicked.connect(lambda: self.delete_modal.show())
self.delete_button.clicked.connect(self.delete_modal.show)
self.button_container = QWidget()
self.button_layout = QHBoxLayout(self.button_container)
@@ -122,83 +105,35 @@ class FixUnlinkedEntriesModal(QWidget):
self.root_layout.addWidget(self.unlinked_desc_widget)
self.root_layout.addWidget(self.refresh_unlinked_button)
self.root_layout.addWidget(self.search_button)
self.manual_button.setHidden(True)
self.root_layout.addWidget(self.manual_button)
self.root_layout.addWidget(self.delete_button)
self.root_layout.addStretch(1)
self.root_layout.addWidget(self.dupe_count_label)
self.root_layout.addWidget(self.dupe_desc_widget)
self.root_layout.addWidget(self.refresh_dupe_button)
self.root_layout.addWidget(self.merge_dupe_button)
self.root_layout.addStretch(2)
self.root_layout.addWidget(self.button_container)
self.set_missing_count(self.missing_count)
self.set_dupe_count(self.dupe_count)
def refresh_missing_files(self):
iterator = FunctionIterator(self.lib.refresh_missing_files)
pw = ProgressWidget(
window_title="Scanning Library",
label_text="Scanning Library for Unlinked Entries...",
cancel_button_text=None,
minimum=0,
maximum=len(self.lib.entries),
maximum=self.lib.entries_count,
)
pw.show()
iterator = FunctionIterator(self.tracker.refresh_missing_files)
iterator.value.connect(lambda v: pw.update_progress(v + 1))
r = CustomRunnable(lambda: iterator.run())
r = CustomRunnable(iterator.run)
QThreadPool.globalInstance().start(r)
r.done.connect(
lambda: (
pw.hide(),
pw.deleteLater(),
self.set_missing_count(len(self.lib.missing_files)),
self.set_missing_count(self.tracker.missing_files_count),
self.delete_modal.refresh_list(),
self.refresh_dupe_entries(),
)
)
def refresh_dupe_entries(self):
iterator = FunctionIterator(self.lib.refresh_dupe_entries)
pw = ProgressWidget(
window_title="Scanning Library",
label_text="Scanning Library for Duplicate Entries...",
cancel_button_text=None,
minimum=0,
maximum=len(self.lib.entries),
)
pw.show()
iterator.value.connect(lambda v: pw.update_progress(v + 1))
r = CustomRunnable(lambda: iterator.run())
QThreadPool.globalInstance().start(r)
r.done.connect(
lambda: (
pw.hide(),
pw.deleteLater(),
self.set_dupe_count(len(self.lib.dupe_entries)),
)
)
def refresh_and_repair_dupe_entries(self, merge_class: MergeDuplicateEntries):
iterator = FunctionIterator(self.lib.refresh_dupe_entries)
pw = ProgressWidget(
window_title="Scanning Library",
label_text="Scanning Library for Duplicate Entries...",
cancel_button_text=None,
minimum=0,
maximum=len(self.lib.entries),
)
pw.show()
iterator.value.connect(lambda v: pw.update_progress(v + 1))
r = CustomRunnable(lambda: iterator.run())
QThreadPool.globalInstance().start(r)
r.done.connect(
lambda: (
pw.hide(), # type: ignore
pw.deleteLater(), # type: ignore
self.set_dupe_count(len(self.lib.dupe_entries)),
merge_class.merge_entries(),
)
)
@@ -208,23 +143,8 @@ class FixUnlinkedEntriesModal(QWidget):
self.search_button.setDisabled(True)
self.delete_button.setDisabled(True)
self.missing_count_label.setText("Unlinked Entries: N/A")
elif self.missing_count == 0:
self.search_button.setDisabled(True)
self.delete_button.setDisabled(True)
self.missing_count_label.setText(f"Unlinked Entries: {count}")
else:
self.search_button.setDisabled(False)
self.delete_button.setDisabled(False)
# disable buttons if there are no files to fix
self.search_button.setDisabled(self.missing_count == 0)
self.delete_button.setDisabled(self.missing_count == 0)
self.missing_count_label.setText(f"Unlinked Entries: {count}")
def set_dupe_count(self, count: int):
self.dupe_count = count
if self.dupe_count < 0:
self.dupe_count_label.setText("Duplicate Entries: N/A")
self.merge_dupe_button.setDisabled(True)
elif self.dupe_count == 0:
self.dupe_count_label.setText(f"Duplicate Entries: {count}")
self.merge_dupe_button.setDisabled(True)
else:
self.dupe_count_label.setText(f"Duplicate Entries: {count}")
self.merge_dupe_button.setDisabled(False)

View File

@@ -3,10 +3,11 @@
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
import math
import typing
from dataclasses import dataclass, field
import structlog
from PySide6.QtCore import Qt
from PySide6.QtWidgets import (
QWidget,
@@ -18,142 +19,143 @@ from PySide6.QtWidgets import (
QFrame,
)
from src.core.enums import FieldID
from src.core.library import Library, Tag
from src.core.constants import TAG_FAVORITE, TAG_ARCHIVED
from src.core.library import Tag, Library
from src.core.library.alchemy.fields import _FieldID
from src.core.palette import ColorType, get_tag_color
from src.qt.flowlayout import FlowLayout
# Only import for type checking/autocompletion, will not be imported at runtime.
if typing.TYPE_CHECKING:
from src.qt.ts_qt import QtDriver
logger = structlog.get_logger(__name__)
ERROR = f"[ERROR]"
WARNING = f"[WARNING]"
INFO = f"[INFO]"
logging.basicConfig(format="%(message)s", level=logging.INFO)
@dataclass
class BranchData:
dirs: dict[str, "BranchData"] = field(default_factory=dict)
files: list[str] = field(default_factory=list)
tag: Tag | None = None
def add_folders_to_tree(
library: Library, tree: BranchData, items: tuple[str, ...]
) -> BranchData:
branch = tree
for folder in items:
if folder not in branch.dirs:
# TODO - subtags
new_tag = Tag(name=folder)
library.add_tag(new_tag)
branch.dirs[folder] = BranchData(tag=new_tag)
branch.tag = new_tag
branch = branch.dirs[folder]
return branch
def folders_to_tags(library: Library):
logging.info("Converting folders to Tags")
tree: dict = dict(dirs={})
logger.info("Converting folders to Tags")
tree = BranchData()
def add_tag_to_tree(items: list[Tag]):
branch = tree
for tag in items:
if tag.name not in branch["dirs"]:
branch["dirs"][tag.name] = dict(dirs={}, tag=tag)
branch = branch["dirs"][tag.name]
def add_folders_to_tree(items: list[str]) -> Tag:
branch: dict = tree
for folder in items:
if folder not in branch["dirs"]:
new_tag = Tag(
-1,
folder,
"",
[],
([branch["tag"].id] if "tag" in branch else []),
"",
)
library.add_tag_to_library(new_tag)
branch["dirs"][folder] = dict(dirs={}, tag=new_tag)
branch = branch["dirs"][folder]
return branch.get("tag")
if tag.name not in branch.dirs:
branch.dirs[tag.name] = BranchData()
branch = branch.dirs[tag.name]
for tag in library.tags:
reversed_tag = reverse_tag(library, tag, None)
add_tag_to_tree(reversed_tag)
for entry in library.entries:
folders = list(entry.path.parts)
if len(folders) == 1 and folders[0] == "":
for entry in library.get_entries():
folders = entry.path.parts[1:-1]
if not folders:
continue
tag = add_folders_to_tree(folders)
if tag:
if not entry.has_tag(library, tag.id):
entry.add_tag(library, tag.id, FieldID.TAGS)
logging.info("Done")
tag = add_folders_to_tree(library, tree, folders).tag
if tag and not entry.has_tag(tag):
library.add_field_tag(entry, tag, _FieldID.TAGS.name, create_field=True)
logger.info("Done")
def reverse_tag(library: Library, tag: Tag, list: list[Tag]) -> list[Tag]:
if list is not None:
list.append(tag)
else:
list = [tag]
def reverse_tag(library: Library, tag: Tag, items: list[Tag] | None) -> list[Tag]:
items = items or []
items.append(tag)
if len(tag.subtag_ids) == 0:
list.reverse()
return list
else:
for subtag_id in tag.subtag_ids:
subtag = library.get_tag(subtag_id)
return reverse_tag(library, subtag, list)
if not tag.subtag_ids:
items.reverse()
return items
for subtag_id in tag.subtag_ids:
subtag = library.get_tag(subtag_id)
return reverse_tag(library, subtag, items)
# =========== UI ===========
def generate_preview_data(library: Library):
tree: dict = dict(dirs={}, files=[])
def generate_preview_data(library: Library) -> BranchData:
tree = BranchData()
def add_tag_to_tree(items: list[Tag]):
branch: dict = tree
branch = tree
for tag in items:
if tag.name not in branch["dirs"]:
branch["dirs"][tag.name] = dict(dirs={}, tag=tag, files=[])
branch = branch["dirs"][tag.name]
if tag.name not in branch.dirs:
branch.dirs[tag.name] = BranchData(tag=tag)
branch = branch.dirs[tag.name]
def add_folders_to_tree(items: list[str]) -> dict:
branch: dict = tree
def _add_folders_to_tree(items: typing.Sequence[str]) -> BranchData:
branch = tree
for folder in items:
if folder not in branch["dirs"]:
new_tag = Tag(-1, folder, "", [], [], "green")
branch["dirs"][folder] = dict(dirs={}, tag=new_tag, files=[])
branch = branch["dirs"][folder]
if folder not in branch.dirs:
new_tag = Tag(name=folder)
branch.dirs[folder] = BranchData(tag=new_tag)
branch = branch.dirs[folder]
return branch
for tag in library.tags:
if tag.id in (TAG_FAVORITE, TAG_ARCHIVED):
continue
reversed_tag = reverse_tag(library, tag, None)
add_tag_to_tree(reversed_tag)
for entry in library.entries:
folders = list(entry.path.parts)
if len(folders) == 1 and folders[0] == "":
for entry in library.get_entries():
folders = entry.path.parts[1:-1]
if not folders:
continue
branch = add_folders_to_tree(folders)
branch = _add_folders_to_tree(folders)
if branch:
field_indexes = library.get_field_index_in_entry(entry, 6)
has_tag = False
for index in field_indexes:
content = library.get_field_attr(entry.fields[index], "content")
for tag_id in content:
tag = library.get_tag(tag_id)
if tag.name == branch["tag"].name:
for tag_field in entry.tag_box_fields:
for tag in tag_field.tags:
if tag.name == branch.tag.name:
has_tag = True
break
if not has_tag:
branch["files"].append(entry.filename)
branch.files.append(entry.path.name)
def cut_branches_adding_nothing(branch: dict):
folders = set(branch["dirs"].keys())
def cut_branches_adding_nothing(branch: BranchData) -> bool:
folders = list(branch.dirs.keys())
for folder in folders:
cut = cut_branches_adding_nothing(branch["dirs"][folder])
cut = cut_branches_adding_nothing(branch.dirs[folder])
if cut:
branch["dirs"].pop(folder)
branch.dirs.pop(folder)
if "tag" not in branch:
return
if branch["tag"].id == -1 or len(branch["files"]) > 0: # Needs to be first
if not branch.tag:
return False
if len(branch["dirs"].keys()) == 0:
return True
if not branch.tag.id:
return False
if branch.files:
return False
return not bool(branch.dirs)
cut_branches_adding_nothing(tree)
return tree
@@ -166,7 +168,7 @@ class FoldersToTagsModal(QWidget):
self.count = -1
self.filename = ""
self.setWindowTitle(f"Create Tags From Folders")
self.setWindowTitle("Create Tags From Folders")
self.setWindowModality(Qt.WindowModality.ApplicationModal)
self.setMinimumSize(640, 640)
self.root_layout = QVBoxLayout(self)
@@ -242,19 +244,19 @@ class FoldersToTagsModal(QWidget):
data = generate_preview_data(self.library)
for folder in data["dirs"].values():
test = TreeItem(folder, None)
for folder in data.dirs.values():
test = TreeItem(folder)
self.scroll_layout.addWidget(test)
def set_all_branches(self, hidden: bool):
for i in reversed(range(self.scroll_layout.count())):
child = self.scroll_layout.itemAt(i).widget()
if type(child) == TreeItem:
if isinstance(child, TreeItem):
child.set_all_branches(hidden)
class TreeItem(QWidget):
def __init__(self, data: dict, parentTag: Tag):
def __init__(self, data: BranchData, parent_tag: Tag | None = None):
super().__init__()
self.setStyleSheet("QLabel{font-size: 13px}")
@@ -270,7 +272,7 @@ class TreeItem(QWidget):
self.label = QLabel()
self.tag_layout.addWidget(self.label)
self.tag_widget = ModifiedTagWidget(data["tag"], parentTag)
self.tag_widget = ModifiedTagWidget(data.tag, parent_tag)
self.tag_widget.bg_button.clicked.connect(lambda: self.hide_show())
self.tag_layout.addWidget(self.tag_widget)
@@ -284,24 +286,24 @@ class TreeItem(QWidget):
self.children_widget.setHidden(not self.children_widget.isHidden())
self.label.setText(">" if self.children_widget.isHidden() else "v")
def populate(self, data: dict):
for folder in data["dirs"].values():
item = TreeItem(folder, data["tag"])
def populate(self, data: BranchData):
for folder in data.dirs.values():
item = TreeItem(folder, data.tag)
self.children_layout.addWidget(item)
for file in data["files"]:
for file in data.files:
label = QLabel()
label.setText(" -> " + str(file))
self.children_layout.addWidget(label)
if len(data["files"]) == 0 and len(data["dirs"].values()) == 0:
self.hide_show()
else:
if data.files or data.dirs:
self.label.setText("v")
else:
self.hide_show()
def set_all_branches(self, hidden: bool):
for i in reversed(range(self.children_layout.count())):
child = self.children_layout.itemAt(i).widget()
if type(child) == TreeItem:
if isinstance(child, TreeItem):
child.set_all_branches(hidden)
self.children_widget.setHidden(hidden)
@@ -343,7 +345,7 @@ class ModifiedTagWidget(
f"border-color:{get_tag_color(ColorType.BORDER, tag.color)};"
f"border-radius: 6px;"
f"border-style:inset;"
f"border-width: {math.ceil(1*self.devicePixelRatio())}px;"
f"border-width: {math.ceil(self.devicePixelRatio())}px;"
f"padding-right: 4px;"
f"padding-bottom: 1px;"
f"padding-left: 4px;"

View File

@@ -7,6 +7,7 @@ import typing
from PySide6.QtCore import QObject, Signal, QThreadPool
from src.core.library import Library
from src.core.utils.dupe_files import DupeRegistry
from src.qt.helpers.function_iterator import FunctionIterator
from src.qt.helpers.custom_runnable import CustomRunnable
from src.qt.widgets.progress import ProgressWidget
@@ -23,16 +24,17 @@ class MergeDuplicateEntries(QObject):
super().__init__()
self.lib = library
self.driver = driver
self.tracker = DupeRegistry(library=self.lib)
def merge_entries(self):
iterator = FunctionIterator(self.lib.merge_dupe_entries)
iterator = FunctionIterator(self.tracker.merge_dupe_entries)
pw = ProgressWidget(
window_title="Merging Duplicate Entries",
label_text="",
cancel_button_text=None,
minimum=0,
maximum=len(self.lib.dupe_entries),
maximum=self.tracker.groups_count,
)
pw.show()
@@ -41,6 +43,6 @@ class MergeDuplicateEntries(QObject):
lambda: (pw.update_label("Merging Duplicate Entries..."))
)
r = CustomRunnable(lambda: iterator.run())
r = CustomRunnable(iterator.run)
r.done.connect(lambda: (pw.hide(), pw.deleteLater(), self.done.emit()))
QThreadPool.globalInstance().start(r)

View File

@@ -17,7 +17,7 @@ from PySide6.QtWidgets import (
QListView,
)
from src.core.library import Library
from src.core.utils.dupe_files import DupeRegistry
from src.qt.helpers.function_iterator import FunctionIterator
from src.qt.helpers.custom_runnable import CustomRunnable
from src.qt.widgets.progress import ProgressWidget
@@ -30,21 +30,22 @@ if typing.TYPE_CHECKING:
class MirrorEntriesModal(QWidget):
done = Signal()
def __init__(self, library: "Library", driver: "QtDriver"):
def __init__(self, driver: "QtDriver", tracker: DupeRegistry):
super().__init__()
self.lib = library
self.driver = driver
self.setWindowTitle(f"Mirror Entries")
self.setWindowTitle("Mirror Entries")
self.setWindowModality(Qt.WindowModality.ApplicationModal)
self.setMinimumSize(500, 400)
self.root_layout = QVBoxLayout(self)
self.root_layout.setContentsMargins(6, 6, 6, 6)
self.tracker = tracker
self.desc_widget = QLabel()
self.desc_widget.setObjectName("descriptionLabel")
self.desc_widget.setWordWrap(True)
self.desc_widget.setText(f"""
Are you sure you want to mirror the following {len(self.lib.dupe_files)} Entries?
Are you sure you want to mirror the following {self.tracker.groups_count} Entries?
""")
self.desc_widget.setAlignment(Qt.AlignmentFlag.AlignCenter)
@@ -66,7 +67,7 @@ class MirrorEntriesModal(QWidget):
self.mirror_button = QPushButton()
self.mirror_button.setText("&Mirror")
self.mirror_button.clicked.connect(self.hide)
self.mirror_button.clicked.connect(lambda: self.mirror_entries())
self.mirror_button.clicked.connect(self.mirror_entries)
self.button_layout.addWidget(self.mirror_button)
self.root_layout.addWidget(self.desc_widget)
@@ -75,45 +76,30 @@ class MirrorEntriesModal(QWidget):
def refresh_list(self):
self.desc_widget.setText(f"""
Are you sure you want to mirror the following {len(self.lib.dupe_files)} Entries?
Are you sure you want to mirror the following {self.tracker.groups_count} Entries?
""")
self.model.clear()
for i in self.lib.dupe_files:
for i in self.tracker.groups:
self.model.appendRow(QStandardItem(str(i)))
def mirror_entries(self):
# pb = QProgressDialog('', None, 0, len(self.lib.dupe_files))
# # pb.setMaximum(len(self.lib.missing_files))
# pb.setFixedSize(432, 112)
# pb.setWindowFlags(pb.windowFlags() & ~Qt.WindowType.WindowCloseButtonHint)
# pb.setWindowTitle('Mirroring Entries')
# pb.setWindowModality(Qt.WindowModality.ApplicationModal)
# pb.show()
# r = CustomRunnable(lambda: self.mirror_entries_runnable(pb))
# r.done.connect(lambda: self.done.emit())
# r.done.connect(lambda: self.driver.preview_panel.refresh())
# # r.done.connect(lambda: self.model.clear())
# # QThreadPool.globalInstance().start(r)
# r.run()
iterator = FunctionIterator(self.mirror_entries_runnable)
pw = ProgressWidget(
window_title="Mirroring Entries",
label_text=f"Mirroring 1/{len(self.lib.dupe_files)} Entries...",
label_text=f"Mirroring 1/{self.tracker.groups_count} Entries...",
cancel_button_text=None,
minimum=0,
maximum=len(self.lib.dupe_files),
maximum=self.tracker.groups_count,
)
pw.show()
iterator.value.connect(lambda x: pw.update_progress(x + 1))
iterator.value.connect(
lambda x: pw.update_label(
f"Mirroring {x+1}/{len(self.lib.dupe_files)} Entries..."
f"Mirroring {x + 1}/{self.tracker.groups_count} Entries..."
)
)
r = CustomRunnable(lambda: iterator.run())
r = CustomRunnable(iterator.run)
QThreadPool.globalInstance().start(r)
r.done.connect(
lambda: (
@@ -126,15 +112,11 @@ class MirrorEntriesModal(QWidget):
def mirror_entries_runnable(self):
mirrored: list = []
for i, dupe in enumerate(self.lib.dupe_files):
# pb.setValue(i)
# pb.setLabelText(f'Mirroring {i}/{len(self.lib.dupe_files)} Entries')
entry_id_1 = self.lib.get_entry_id_from_filepath(dupe[0])
entry_id_2 = self.lib.get_entry_id_from_filepath(dupe[1])
self.lib.mirror_entry_fields([entry_id_1, entry_id_2])
lib = self.driver.lib
for i, entries in enumerate(self.tracker.groups):
lib.mirror_entry_fields(*entries)
sleep(0.005)
yield i
for d in mirrored:
self.lib.dupe_files.remove(d)
# self.driver.filter_items('')
# self.done.emit()
self.tracker.groups.remove(d)

View File

@@ -2,60 +2,51 @@
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import typing
from PySide6.QtCore import QObject, Signal, QThreadPool
from src.core.library import Library
from src.core.utils.missing_files import MissingRegistry
from src.qt.helpers.function_iterator import FunctionIterator
from src.qt.helpers.custom_runnable import CustomRunnable
from src.qt.widgets.progress import ProgressWidget
# Only import for type checking/autocompletion, will not be imported at runtime.
if typing.TYPE_CHECKING:
from src.qt.ts_qt import QtDriver
class RelinkUnlinkedEntries(QObject):
done = Signal()
def __init__(self, library: "Library", driver: "QtDriver"):
def __init__(self, tracker: MissingRegistry):
super().__init__()
self.lib = library
self.driver = driver
self.fixed = 0
self.tracker = tracker
def repair_entries(self):
iterator = FunctionIterator(self.lib.fix_missing_files)
iterator = FunctionIterator(self.tracker.fix_missing_files)
pw = ProgressWidget(
window_title="Relinking Entries",
label_text="",
cancel_button_text=None,
minimum=0,
maximum=len(self.lib.missing_files),
maximum=self.tracker.missing_files_count,
)
pw.show()
iterator.value.connect(lambda x: pw.update_progress(x[0] + 1))
iterator.value.connect(
lambda x: (
self.increment_fixed() if x[1] else (),
lambda idx: (
pw.update_progress(idx),
pw.update_label(
f"Attempting to Relink {x[0]+1}/{len(self.lib.missing_files)} Entries, {self.fixed} Successfully Relinked"
f"Attempting to Relink {idx}/{self.tracker.missing_files_count} Entries. "
f"{self.tracker.files_fixed_count} Successfully Relinked."
),
)
)
r = CustomRunnable(lambda: iterator.run())
r = CustomRunnable(iterator.run)
r.done.connect(
lambda: (pw.hide(), pw.deleteLater(), self.done.emit(), self.reset_fixed())
lambda: (
pw.hide(),
pw.deleteLater(),
self.done.emit(),
)
)
QThreadPool.globalInstance().start(r)
def increment_fixed(self):
self.fixed += 1
def reset_fixed(self):
self.fixed = 0

View File

@@ -12,7 +12,8 @@ from PySide6.QtWidgets import (
QFrame,
)
from src.core.library import Library
from src.core.library import Library, Tag
from src.core.library.alchemy.enums import FilterState
from src.qt.widgets.panel import PanelWidget, PanelModal
from src.qt.widgets.tag import TagWidget
from src.qt.modals.build_tag import BuildTagPanel
@@ -21,7 +22,7 @@ from src.qt.modals.build_tag import BuildTagPanel
class TagDatabasePanel(PanelWidget):
tag_chosen = Signal(int)
def __init__(self, library):
def __init__(self, library: Library):
super().__init__()
self.lib: Library = library
# self.callback = callback
@@ -38,7 +39,7 @@ class TagDatabasePanel(PanelWidget):
self.search_field.setMinimumSize(QSize(0, 32))
self.search_field.setPlaceholderText("Search Tags")
self.search_field.textEdited.connect(
lambda x=self.search_field.text(): self.update_tags(x)
lambda: self.update_tags(self.search_field.text())
)
self.search_field.returnPressed.connect(
lambda checked=False: self.on_return(self.search_field.text())
@@ -73,7 +74,7 @@ class TagDatabasePanel(PanelWidget):
self.root_layout.addWidget(self.search_field)
self.root_layout.addWidget(self.scroll_area)
self.update_tags("")
self.update_tags()
# def reset(self):
# self.search_field.setText('')
@@ -84,60 +85,47 @@ class TagDatabasePanel(PanelWidget):
if text and self.first_tag_id >= 0:
# callback(self.first_tag_id)
self.search_field.setText("")
self.update_tags("")
self.update_tags()
else:
self.search_field.setFocus()
self.parentWidget().hide()
def update_tags(self, query: str):
def update_tags(self, query: str | None = None):
# TODO: Look at recycling rather than deleting and reinitializing
while self.scroll_layout.itemAt(0):
self.scroll_layout.takeAt(0).widget().deleteLater()
# If there is a query, get a list of tag_ids that match, otherwise return all
if query:
tags = self.lib.search_tags(query, include_cluster=True)[
: self.tag_limit - 1
]
else:
# Get tag ids to keep this behaviorally identical
tags = [t.id for t in self.lib.tags]
tags = self.lib.search_tags(FilterState(path=query, page_size=self.tag_limit))
first_id_set = False
for tag_id in tags:
if not first_id_set:
self.first_tag_id = tag_id
first_id_set = True
for tag in tags:
container = QWidget()
row = QHBoxLayout(container)
row.setContentsMargins(0, 0, 0, 0)
row.setSpacing(3)
tw = TagWidget(self.lib, self.lib.get_tag(tag_id), True, False)
tw.on_edit.connect(
lambda checked=False, t=self.lib.get_tag(tag_id): (self.edit_tag(t.id))
)
row.addWidget(tw)
tag_widget = TagWidget(tag, True, False)
tag_widget.on_edit.connect(lambda checked=False, t=tag: self.edit_tag(t))
row.addWidget(tag_widget)
self.scroll_layout.addWidget(container)
self.search_field.setFocus()
def edit_tag(self, tag_id: int):
btp = BuildTagPanel(self.lib, tag_id)
# btp.on_edit.connect(lambda x: self.edit_tag_callback(x))
def edit_tag(self, tag: Tag):
build_tag_panel = BuildTagPanel(self.lib, tag=tag)
self.edit_modal = PanelModal(
btp,
self.lib.get_tag(tag_id).display_name(self.lib),
build_tag_panel,
tag.name,
"Edit Tag",
done_callback=(self.update_tags(self.search_field.text())),
has_save=True,
)
# self.edit_modal.widget.update_display_name.connect(lambda t: self.edit_modal.title_widget.setText(t))
# TODO Check Warning: Expected type 'BuildTagPanel', got 'PanelWidget' instead
self.edit_modal.saved.connect(lambda: self.edit_tag_callback(btp))
self.edit_modal.saved.connect(lambda: self.edit_tag_callback(build_tag_panel))
self.edit_modal.show()
def edit_tag_callback(self, btp: BuildTagPanel):
self.lib.update_tag(btp.build_tag())
self.lib.add_tag(btp.build_tag())
self.update_tags(self.search_field.text())
# def enterEvent(self, event: QEnterEvent) -> None:

View File

@@ -3,9 +3,9 @@
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
import math
import structlog
from PySide6.QtCore import Signal, Qt, QSize
from PySide6.QtWidgets import (
QWidget,
@@ -18,24 +18,20 @@ from PySide6.QtWidgets import (
)
from src.core.library import Library
from src.core.library.alchemy.enums import FilterState
from src.core.palette import ColorType, get_tag_color
from src.qt.widgets.panel import PanelWidget
from src.qt.widgets.tag import TagWidget
ERROR = f"[ERROR]"
WARNING = f"[WARNING]"
INFO = f"[INFO]"
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = structlog.get_logger(__name__)
class TagSearchPanel(PanelWidget):
tag_chosen = Signal(int)
def __init__(self, library):
def __init__(self, library: Library):
super().__init__()
self.lib: Library = library
self.lib = library
# self.callback = callback
self.first_tag_id = None
self.tag_limit = 100
@@ -49,7 +45,7 @@ class TagSearchPanel(PanelWidget):
self.search_field.setMinimumSize(QSize(0, 32))
self.search_field.setPlaceholderText("Search Tags")
self.search_field.textEdited.connect(
lambda x=self.search_field.text(): self.update_tags(x)
lambda: self.update_tags(self.search_field.text())
)
self.search_field.returnPressed.connect(
lambda checked=False: self.on_return(self.search_field.text())
@@ -84,7 +80,7 @@ class TagSearchPanel(PanelWidget):
self.root_layout.addWidget(self.search_field)
self.root_layout.addWidget(self.scroll_area)
self.update_tags("")
self.update_tags()
# def reset(self):
# self.search_field.setText('')
@@ -101,55 +97,51 @@ class TagSearchPanel(PanelWidget):
self.search_field.setFocus()
self.parentWidget().hide()
def update_tags(self, query: str = ""):
# for c in self.scroll_layout.children():
# c.widget().deleteLater()
def update_tags(self, name: str | None = None):
while self.scroll_layout.count():
# logging.info(f"I'm deleting { self.scroll_layout.itemAt(0).widget()}")
self.scroll_layout.takeAt(0).widget().deleteLater()
found_tags = self.lib.search_tags(query, include_cluster=True)[: self.tag_limit]
self.first_tag_id = found_tags[0] if found_tags else None
found_tags = self.lib.search_tags(
FilterState(
path=name,
page_size=self.tag_limit,
)
)
for tag_id in found_tags:
for tag in found_tags:
c = QWidget()
l = QHBoxLayout(c)
l.setContentsMargins(0, 0, 0, 0)
l.setSpacing(3)
tw = TagWidget(self.lib, self.lib.get_tag(tag_id), False, False)
layout = QHBoxLayout(c)
layout.setContentsMargins(0, 0, 0, 0)
layout.setSpacing(3)
tw = TagWidget(tag, False, False)
ab = QPushButton()
ab.setMinimumSize(23, 23)
ab.setMaximumSize(23, 23)
ab.setText("+")
ab.setStyleSheet(
f"QPushButton{{"
f"background: {get_tag_color(ColorType.PRIMARY, self.lib.get_tag(tag_id).color)};"
# f'background-color: qlineargradient(x1:0, y1:0, x2:0, y2:1, stop:0 {get_tag_color(ColorType.PRIMARY, tag.color)}, stop:1.0 {get_tag_color(ColorType.BORDER, tag.color)});'
# f"border-color:{get_tag_color(ColorType.PRIMARY, tag.color)};"
f"color: {get_tag_color(ColorType.TEXT, self.lib.get_tag(tag_id).color)};"
f"background: {get_tag_color(ColorType.PRIMARY, tag.color)};"
f"color: {get_tag_color(ColorType.TEXT, tag.color)};"
f"font-weight: 600;"
f"border-color:{get_tag_color(ColorType.BORDER, self.lib.get_tag(tag_id).color)};"
f"border-color:{get_tag_color(ColorType.BORDER, tag.color)};"
f"border-radius: 6px;"
f"border-style:solid;"
f"border-width: {math.ceil(1*self.devicePixelRatio())}px;"
# f'padding-top: 1.5px;'
# f'padding-right: 4px;'
f"border-width: {math.ceil(self.devicePixelRatio())}px;"
f"padding-bottom: 5px;"
# f'padding-left: 4px;'
f"font-size: 20px;"
f"}}"
f"QPushButton::hover"
f"{{"
f"border-color:{get_tag_color(ColorType.LIGHT_ACCENT, self.lib.get_tag(tag_id).color)};"
f"color: {get_tag_color(ColorType.DARK_ACCENT, self.lib.get_tag(tag_id).color)};"
f"background: {get_tag_color(ColorType.LIGHT_ACCENT, self.lib.get_tag(tag_id).color)};"
f"border-color:{get_tag_color(ColorType.LIGHT_ACCENT, tag.color)};"
f"color: {get_tag_color(ColorType.DARK_ACCENT, tag.color)};"
f"background: {get_tag_color(ColorType.LIGHT_ACCENT, tag.color)};"
f"}}"
)
ab.clicked.connect(lambda checked=False, x=tag_id: self.tag_chosen.emit(x))
ab.clicked.connect(lambda checked=False, x=tag.id: self.tag_chosen.emit(x))
l.addWidget(tw)
l.addWidget(ab)
layout.addWidget(tw)
layout.addWidget(ab)
self.scroll_layout.addWidget(c)
self.search_field.setFocus()

View File

@@ -10,7 +10,6 @@ from PySide6.QtGui import QIntValidator
from PySide6.QtWidgets import (
QWidget,
QHBoxLayout,
QPushButton,
QLabel,
QLineEdit,
QSizePolicy,
@@ -274,7 +273,7 @@ class Pagination(QWidget, QObject):
if self.end_buffer_layout.itemAt(i):
self.end_buffer_layout.itemAt(i).widget().setHidden(True)
sbc += 1
self.current_page_field.setText((str(i + 1)))
self.current_page_field.setText(str(i + 1))
# elif index == page_count-1:
# self.start_button.setText(str(page_count))
@@ -419,7 +418,6 @@ class Pagination(QWidget, QObject):
self.validator.setTop(page_count)
# if self.current_page_index != index:
if emit:
print(f"[PAGINATION] Emitting {index}")
self.index.emit(index)
self.current_page_index = index
self.page_count = page_count
@@ -435,7 +433,7 @@ class Pagination(QWidget, QObject):
button.is_connected = True
def _populate_buffer_buttons(self):
for i in range(max(self.buffer_page_count * 2, 5)):
for _ in range(max(self.buffer_page_count * 2, 5)):
button = QPushButtonWrapper()
button.setMinimumSize(self.button_size)
button.setMaximumSize(self.button_size)
@@ -443,13 +441,12 @@ class Pagination(QWidget, QObject):
# button.setMaximumHeight(self.button_size.height())
self.start_buffer_layout.addWidget(button)
for i in range(max(self.buffer_page_count * 2, 5)):
button = QPushButtonWrapper()
button.setMinimumSize(self.button_size)
button.setMaximumSize(self.button_size)
button.setHidden(True)
end_button = QPushButtonWrapper()
end_button.setMinimumSize(self.button_size)
end_button.setMaximumSize(self.button_size)
end_button.setHidden(True)
# button.setMaximumHeight(self.button_size.height())
self.end_buffer_layout.addWidget(button)
self.end_buffer_layout.addWidget(end_button)
class Validator(QIntValidator):

View File

@@ -2,13 +2,13 @@
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
from pathlib import Path
from typing import Any
import structlog
import ujson
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = structlog.get_logger(__name__)
class ResourceManager:
@@ -21,12 +21,10 @@ class ResourceManager:
def __init__(self) -> None:
# Load JSON resource map
if not ResourceManager._initialized:
with open(
Path(__file__).parent / "resources.json", mode="r", encoding="utf-8"
) as f:
with open(Path(__file__).parent / "resources.json", encoding="utf-8") as f:
ResourceManager._map = ujson.load(f)
logging.info(
f"[ResourceManager] {len(ResourceManager._map.items())} resources registered"
logger.info(
"resources registered", count=len(ResourceManager._map.items())
)
ResourceManager._initialized = True

File diff suppressed because it is too large Load Diff

View File

@@ -2,37 +2,22 @@
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
import os
import traceback
from pathlib import Path
import cv2
import structlog
from PIL import Image, ImageChops, UnidentifiedImageError
from PIL.Image import DecompressionBombError
from PySide6.QtCore import (
QObject,
QThread,
Signal,
QRunnable,
Qt,
QThreadPool,
QSize,
QEvent,
QTimer,
QSettings,
)
from src.core.library import Library
from src.core.constants import DOC_TYPES, VIDEO_TYPES, IMAGE_TYPES
from src.core.library import Library
from src.core.library.alchemy.fields import _FieldID
ERROR = f"[ERROR]"
WARNING = f"[WARNING]"
INFO = f"[INFO]"
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = structlog.get_logger(__name__)
class CollageIconRenderer(QObject):
@@ -52,26 +37,22 @@ class CollageIconRenderer(QObject):
keep_aspect,
):
entry = self.lib.get_entry(entry_id)
filepath = self.lib.library_dir / entry.path / entry.filename
file_type = os.path.splitext(filepath)[1].lower()[1:]
filepath = self.lib.library_dir / entry.path
color: str = ""
try:
if data_tint_mode or data_only_mode:
color = "#000000" # Black (Default)
if entry.fields:
has_any_tags: bool = False
has_content_tags: bool = False
has_meta_tags: bool = False
for field in entry.fields:
if self.lib.get_field_attr(field, "type") == "tag_box":
if self.lib.get_field_attr(field, "content"):
has_any_tags = True
if self.lib.get_field_attr(field, "id") == 7:
has_content_tags = True
elif self.lib.get_field_attr(field, "id") == 8:
has_meta_tags = True
for field in entry.tag_box_fields:
if field.tags:
has_any_tags = True
if field.type_key == _FieldID.TAGS_CONTENT.name:
has_content_tags = True
elif field.type_key == _FieldID.TAGS_META.name:
has_meta_tags = True
if has_content_tags and has_meta_tags:
color = "#28bb48" # Green
elif has_any_tags:
@@ -88,16 +69,15 @@ class CollageIconRenderer(QObject):
# collage.paste(pic, (y*thumb_size, x*thumb_size))
self.rendered.emit(pic)
if not data_only_mode:
logging.info(
f"\r{INFO} Combining [ID:{entry_id}/{len(self.lib.entries)}]: {self.get_file_color(filepath.suffix.lower())}{entry.path}{os.sep}{entry.filename}\033[0m"
logger.info(
"Combining icons",
entry=entry,
color=self.get_file_color(filepath.suffix.lower()),
)
# sys.stdout.write(f'\r{INFO} Combining [{i+1}/{len(self.lib.entries)}]: {self.get_file_color(file_type)}{entry.path}{os.sep}{entry.filename}{RESET}')
# sys.stdout.flush()
if filepath.suffix.lower() in IMAGE_TYPES:
try:
with Image.open(
str(self.lib.library_dir / entry.path / entry.filename)
) as pic:
with Image.open(str(self.lib.library_dir / entry.path)) as pic:
if keep_aspect:
pic.thumbnail(size)
else:
@@ -109,8 +89,10 @@ class CollageIconRenderer(QObject):
)
# collage.paste(pic, (y*thumb_size, x*thumb_size))
self.rendered.emit(pic)
except DecompressionBombError as e:
logging.info(f"[ERROR] One of the images was too big ({e})")
except DecompressionBombError:
logger.exception(
"One of the images was too big", entry=entry.path
)
elif filepath.suffix.lower() in VIDEO_TYPES:
video = cv2.VideoCapture(str(filepath))
video.set(
@@ -137,9 +119,7 @@ class CollageIconRenderer(QObject):
# collage.paste(pic, (y*thumb_size, x*thumb_size))
self.rendered.emit(pic)
except (UnidentifiedImageError, FileNotFoundError):
logging.info(
f"\n{ERROR} Couldn't read {entry.path}{os.sep}{entry.filename}"
)
logger.error("Couldn't read entry", entry=entry.path)
with Image.open(
str(
Path(__file__).parents[2]
@@ -153,19 +133,11 @@ class CollageIconRenderer(QObject):
# collage.paste(pic, (y*thumb_size, x*thumb_size))
self.rendered.emit(pic)
except KeyboardInterrupt:
# self.quit(save=False, backup=True)
run = False
# clear()
logging.info("\n")
logging.info(f"{INFO} Collage operation cancelled.")
clear_scr = False
except:
logging.info(f"{ERROR} {entry.path}{os.sep}{entry.filename}")
traceback.print_exc()
logging.info("Continuing...")
logger.info("Collage operation cancelled.")
except Exception:
logger.exception("render failed", entry=entry.path)
self.done.emit()
# logging.info('Done!')
def get_file_color(self, ext: str):
if ext.lower().replace(".", "", 1) == "gif":

View File

@@ -4,15 +4,14 @@
import math
import os
from types import FunctionType, MethodType
from types import MethodType
from pathlib import Path
from typing import Optional, cast, Callable, Any
from typing import Optional, Callable
from PIL import Image, ImageQt
from PySide6.QtCore import Qt, QEvent
from PySide6.QtGui import QPixmap, QEnterEvent
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel
from src.qt.helpers.qbutton_wrapper import QPushButtonWrapper
@@ -41,8 +40,8 @@ class FieldContainer(QWidget):
self.title: str = title
self.inline: bool = inline
# self.editable:bool = editable
self.copy_callback: FunctionType = None
self.edit_callback: FunctionType = None
self.copy_callback: Callable = None
self.edit_callback: Callable = None
self.remove_callback: Callable = None
button_size = 24
# self.setStyleSheet('border-style:solid;border-color:#1e1a33;border-radius:8px;border-width:2px;')
@@ -133,7 +132,7 @@ class FieldContainer(QWidget):
if callback is not None:
self.copy_button.is_connected = True
def set_edit_callback(self, callback: Optional[MethodType]):
def set_edit_callback(self, callback: Callable):
if self.edit_button.is_connected:
self.edit_button.clicked.disconnect()
@@ -142,7 +141,7 @@ class FieldContainer(QWidget):
if callback is not None:
self.edit_button.is_connected = True
def set_remove_callback(self, callback: Optional[Callable]):
def set_remove_callback(self, callback: Callable):
if self.remove_button.is_connected:
self.remove_button.clicked.disconnect()
@@ -160,9 +159,9 @@ class FieldContainer(QWidget):
self.field_layout.itemAt(0).widget().deleteLater()
self.field_layout.addWidget(widget)
def get_inner_widget(self) -> Optional["FieldWidget"]:
def get_inner_widget(self):
if self.field_layout.itemAt(0):
return cast(FieldWidget, self.field_layout.itemAt(0).widget())
return self.field_layout.itemAt(0).widget()
return None
def set_title(self, title: str):
@@ -198,8 +197,6 @@ class FieldContainer(QWidget):
class FieldWidget(QWidget):
field = dict
def __init__(self, title) -> None:
super().__init__()
# self.item = item

View File

@@ -1,14 +1,14 @@
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import contextlib
import logging
import os
import time
import typing
from enum import Enum
from functools import wraps
from pathlib import Path
from typing import Optional
from typing import TYPE_CHECKING
import structlog
from PIL import Image, ImageQt
from PySide6.QtCore import Qt, QSize, QEvent
from PySide6.QtGui import QPixmap, QEnterEvent, QAction
@@ -21,8 +21,6 @@ from PySide6.QtWidgets import (
QCheckBox,
)
from src.core.enums import FieldID
from src.core.library import ItemType, Library, Entry
from src.core.constants import (
AUDIO_TYPES,
VIDEO_TYPES,
@@ -30,20 +28,49 @@ from src.core.constants import (
TAG_FAVORITE,
TAG_ARCHIVED,
)
from src.core.library import ItemType, Entry, Library
from src.core.library.alchemy.enums import FilterState
from src.core.library.alchemy.fields import _FieldID
from src.qt.flowlayout import FlowWidget
from src.qt.helpers.file_opener import FileOpenerHelper
from src.qt.widgets.thumb_renderer import ThumbRenderer
from src.qt.widgets.thumb_button import ThumbButton
if typing.TYPE_CHECKING:
from src.qt.widgets.preview_panel import PreviewPanel
if TYPE_CHECKING:
from src.qt.ts_qt import QtDriver
ERROR = f"[ERROR]"
WARNING = f"[WARNING]"
INFO = f"[INFO]"
logger = structlog.get_logger(__name__)
logging.basicConfig(format="%(message)s", level=logging.INFO)
class BadgeType(Enum):
FAVORITE = "Favorite"
ARCHIVED = "Archived"
BADGE_TAGS = {
BadgeType.FAVORITE: TAG_FAVORITE,
BadgeType.ARCHIVED: TAG_ARCHIVED,
}
def badge_update_lock(func):
"""Prevent recursively triggering badge updates."""
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.driver.badge_update_lock:
return
self.driver.badge_update_lock = True
try:
func(self, *args, **kwargs)
except Exception:
raise
finally:
self.driver.badge_update_lock = False
return wrapper
class ItemThumb(FlowWidget):
@@ -89,19 +116,18 @@ class ItemThumb(FlowWidget):
def __init__(
self,
mode: Optional[ItemType],
mode: ItemType,
library: Library,
panel: "PreviewPanel",
driver: "QtDriver",
thumb_size: tuple[int, int],
grid_idx: int,
):
"""Modes: entry, collation, tag_group"""
super().__init__()
self.grid_idx = grid_idx
self.lib = library
self.panel = panel
self.mode = mode
self.item_id: int = -1
self.isFavorite: bool = False
self.isArchived: bool = False
self.mode: ItemType = mode
self.driver = driver
self.item_id: int | None = None
self.thumb_size: tuple[int, int] = thumb_size
self.setMinimumSize(*thumb_size)
self.setMaximumSize(*thumb_size)
@@ -179,7 +205,7 @@ class ItemThumb(FlowWidget):
lambda ts, i, s, ext: (
self.update_thumb(ts, image=i),
self.update_size(ts, size=s),
self.set_extension(ext), # type: ignore
self.set_extension(ext),
)
)
self.thumb_button.setFlat(True)
@@ -263,54 +289,52 @@ class ItemThumb(FlowWidget):
# self.root_layout.addWidget(self.check_badges, 0, 2)
self.top_layout.addWidget(self.cb_container)
# Favorite Badge -------------------------------------------------------
self.favorite_badge = QCheckBox()
self.favorite_badge.setObjectName("favBadge")
self.favorite_badge.setToolTip("Favorite")
self.favorite_badge.setStyleSheet(
f"QCheckBox::indicator{{width: {check_size}px;height: {check_size}px;}}"
f"QCheckBox::indicator::unchecked{{image: url(:/images/star_icon_empty_128.png)}}"
f"QCheckBox::indicator::checked{{image: url(:/images/star_icon_filled_128.png)}}"
# f'QCheckBox{{background-color:yellow;}}'
)
self.favorite_badge.setMinimumSize(check_size, check_size)
self.favorite_badge.setMaximumSize(check_size, check_size)
self.favorite_badge.stateChanged.connect(
lambda x=self.favorite_badge.isChecked(): self.on_favorite_check(bool(x))
)
self.badge_active: dict[BadgeType, bool] = {
BadgeType.FAVORITE: False,
BadgeType.ARCHIVED: False,
}
# self.fav_badge.setContentsMargins(0,0,0,0)
# tr_layout.addWidget(self.fav_badge)
# root_layout.addWidget(self.fav_badge, 0, 2)
self.cb_layout.addWidget(self.favorite_badge)
self.favorite_badge.setHidden(True)
self.badges: dict[BadgeType, QCheckBox] = {}
badge_icons = {
BadgeType.FAVORITE: (
":/images/star_icon_empty_128.png",
":/images/star_icon_filled_128.png",
),
BadgeType.ARCHIVED: (
":/images/box_icon_empty_128.png",
":/images/box_icon_filled_128.png",
),
}
for badge_type in BadgeType:
icon_empty, icon_checked = badge_icons[badge_type]
badge = QCheckBox()
badge.setObjectName(badge_type.name)
badge.setToolTip(badge_type.value)
badge.setStyleSheet(
f"QCheckBox::indicator{{width: {check_size}px;height: {check_size}px;}}"
f"QCheckBox::indicator::unchecked{{image: url({icon_empty})}}"
f"QCheckBox::indicator::checked{{image: url({icon_checked})}}"
)
badge.setMinimumSize(check_size, check_size)
badge.setMaximumSize(check_size, check_size)
badge.setHidden(True)
# Archive Badge --------------------------------------------------------
self.archived_badge = QCheckBox()
self.archived_badge.setObjectName("archiveBadge")
self.archived_badge.setToolTip("Archive")
self.archived_badge.setStyleSheet(
f"QCheckBox::indicator{{width: {check_size}px;height: {check_size}px;}}"
f"QCheckBox::indicator::unchecked{{image: url(:/images/box_icon_empty_128.png)}}"
f"QCheckBox::indicator::checked{{image: url(:/images/box_icon_filled_128.png)}}"
# f'QCheckBox{{background-color:red;}}'
)
self.archived_badge.setMinimumSize(check_size, check_size)
self.archived_badge.setMaximumSize(check_size, check_size)
# self.archived_badge.clicked.connect(lambda x: self.assign_archived(x))
self.archived_badge.stateChanged.connect(
lambda x=self.archived_badge.isChecked(): self.on_archived_check(bool(x))
)
badge.stateChanged.connect(lambda x, bt=badge_type: self.on_badge_check(bt))
# tr_layout.addWidget(self.archive_badge)
self.cb_layout.addWidget(self.archived_badge)
self.archived_badge.setHidden(True)
# root_layout.addWidget(self.archive_badge, 0, 2)
# self.dumpObjectTree()
self.badges[badge_type] = badge
self.cb_layout.addWidget(badge)
self.set_mode(mode)
def set_mode(self, mode: Optional[ItemType]) -> None:
@property
def is_favorite(self) -> bool:
return self.badge_active[BadgeType.FAVORITE]
@property
def is_archived(self):
return self.badge_active[BadgeType.ARCHIVED]
def set_mode(self, mode: ItemType | None) -> None:
if mode is None:
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, True)
self.unsetCursor()
@@ -318,7 +342,6 @@ class ItemThumb(FlowWidget):
# self.check_badges.setHidden(True)
# self.ext_badge.setHidden(True)
# self.item_type_badge.setHidden(True)
pass
elif mode == ItemType.ENTRY and self.mode != ItemType.ENTRY:
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, False)
self.setCursor(Qt.CursorShape.PointingHandCursor)
@@ -349,11 +372,6 @@ class ItemThumb(FlowWidget):
self.mode = mode
# logging.info(f'Set Mode To: {self.mode}')
# def update_(self, thumb: QPixmap, size:QSize, ext:str, badges:list[QPixmap]) -> None:
# """Updates the ItemThumb's visuals."""
# if thumb:
# pass
def set_extension(self, ext: str) -> None:
if ext and ext.startswith(".") is False:
ext = "." + ext
@@ -376,8 +394,8 @@ class ItemThumb(FlowWidget):
self.ext_badge.setHidden(True)
self.count_badge.setHidden(True)
def update_thumb(self, timestamp: float, image: QPixmap = None):
"""Updates attributes of a thumbnail element."""
def update_thumb(self, timestamp: float, image: QPixmap | None = None):
"""Update attributes of a thumbnail element."""
# logging.info(f'[GUI] Updating Thumbnail for element {id(element)}: {id(image) if image else None}')
if timestamp > ItemThumb.update_cutoff:
self.thumb_button.setIcon(image if image else QPixmap())
@@ -386,11 +404,10 @@ class ItemThumb(FlowWidget):
def update_size(self, timestamp: float, size: QSize):
"""Updates attributes of a thumbnail element."""
# logging.info(f'[GUI] Updating size for element {id(element)}: {size.__str__()}')
if timestamp > ItemThumb.update_cutoff:
if self.thumb_button.iconSize != size:
self.thumb_button.setIconSize(size)
self.thumb_button.setMinimumSize(size)
self.thumb_button.setMaximumSize(size)
if timestamp > ItemThumb.update_cutoff and self.thumb_button.iconSize != size:
self.thumb_button.setIconSize(size)
self.thumb_button.setMinimumSize(size)
self.thumb_button.setMaximumSize(size)
def update_clickable(self, clickable: typing.Callable):
"""Updates attributes of a thumbnail element."""
@@ -401,58 +418,42 @@ class ItemThumb(FlowWidget):
self.thumb_button.clicked.connect(clickable)
self.thumb_button.is_connected = True
def update_badges(self):
if self.mode == ItemType.ENTRY:
# logging.info(f'[UPDATE BADGES] ENTRY: {self.lib.get_entry(self.item_id)}')
# logging.info(f'[UPDATE BADGES] ARCH: {self.lib.get_entry(self.item_id).has_tag(self.lib, 0)}, FAV: {self.lib.get_entry(self.item_id).has_tag(self.lib, 1)}')
self.assign_archived(
self.lib.get_entry(self.item_id).has_tag(self.lib, TAG_ARCHIVED)
)
self.assign_favorite(
self.lib.get_entry(self.item_id).has_tag(self.lib, TAG_FAVORITE)
)
def refresh_badge(self, entry: Entry | None = None):
if not entry:
if not self.item_id:
logger.error("missing both entry and item_id")
return None
def set_item_id(self, id: int):
"""
also sets the filepath for the file opener
"""
self.item_id = id
if id == -1:
return
entry = self.lib.get_entry(self.item_id)
filepath = self.lib.library_dir / entry.path / entry.filename
entry = self.lib.get_entry(self.item_id)
if not entry:
logger.error("Entry not found", item_id=self.item_id)
return
self.assign_badge(BadgeType.ARCHIVED, entry.is_archived)
self.assign_badge(BadgeType.FAVORITE, entry.is_favorited)
def set_item_id(self, entry: Entry):
filepath = self.lib.library_dir / entry.path
self.opener.set_filepath(filepath)
self.item_id = entry.id
def assign_favorite(self, value: bool):
# Switching mode to None to bypass mode-specific operations when the
# checkbox's state changes.
def assign_badge(self, badge_type: BadgeType, value: bool) -> None:
mode = self.mode
# blank mode to avoid recursive badge updates
self.mode = None
self.isFavorite = value
self.favorite_badge.setChecked(value)
if not self.thumb_button.underMouse():
self.favorite_badge.setHidden(not self.isFavorite)
self.mode = mode
badge = self.badges[badge_type]
self.badge_active[badge_type] = value
if badge.isChecked() != value:
badge.setChecked(value)
badge.setHidden(not value)
def assign_archived(self, value: bool):
# Switching mode to None to bypass mode-specific operations when the
# checkbox's state changes.
mode = self.mode
self.mode = None
self.isArchived = value
self.archived_badge.setChecked(value)
if not self.thumb_button.underMouse():
self.archived_badge.setHidden(not self.isArchived)
self.mode = mode
def show_check_badges(self, show: bool):
if self.mode != ItemType.TAG_GROUP:
self.favorite_badge.setHidden(
True if (not show and not self.isFavorite) else False
)
self.archived_badge.setHidden(
True if (not show and not self.isArchived) else False
)
for badge_type, badge in self.badges.items():
is_hidden = not (show or self.badge_active[badge_type])
badge.setHidden(is_hidden)
def enterEvent(self, event: QEnterEvent) -> None:
self.show_check_badges(True)
@@ -462,40 +463,55 @@ class ItemThumb(FlowWidget):
self.show_check_badges(False)
return super().leaveEvent(event)
def on_archived_check(self, toggle_value: bool):
if self.mode == ItemType.ENTRY:
self.isArchived = toggle_value
self.toggle_item_tag(toggle_value, TAG_ARCHIVED)
@badge_update_lock
def on_badge_check(self, badge_type: BadgeType):
if self.mode is None:
return
def on_favorite_check(self, toggle_value: bool):
if self.mode == ItemType.ENTRY:
self.isFavorite = toggle_value
self.toggle_item_tag(toggle_value, TAG_FAVORITE)
toggle_value = self.badges[badge_type].isChecked()
def toggle_item_tag(self, toggle_value: bool, tag_id: int):
def toggle_tag(entry: Entry):
if toggle_value:
self.favorite_badge.setHidden(False)
entry.add_tag(
self.panel.driver.lib,
tag_id,
field_id=FieldID.META_TAGS,
field_index=-1,
)
else:
entry.remove_tag(self.panel.driver.lib, tag_id)
self.badge_active[badge_type] = toggle_value
tag_id = BADGE_TAGS[badge_type]
# Is the badge a part of the selection?
if (ItemType.ENTRY, self.item_id) in self.panel.driver.selected:
# Yes, add chosen tag to all selected.
for _, item_id in self.panel.driver.selected:
entry = self.lib.get_entry(item_id)
toggle_tag(entry)
# check if current item is selected. if so, update all selected items
if self.grid_idx in self.driver.selected:
update_items = self.driver.selected
else:
# No, add tag to the entry this badge is on.
entry = self.lib.get_entry(self.item_id)
toggle_tag(entry)
update_items = [self.grid_idx]
if self.panel.isOpen:
self.panel.update_widgets()
self.panel.driver.update_badges()
for idx in update_items:
entry = self.driver.frame_content[idx]
self.toggle_item_tag(
entry, toggle_value, tag_id, _FieldID.TAGS_META.name, True
)
# update the entry
self.driver.frame_content[idx] = self.lib.search_library(
FilterState(id=entry.id)
)[1][0]
self.driver.update_badges(update_items)
def toggle_item_tag(
self,
entry: Entry,
toggle_value: bool,
tag_id: int,
field_key: str,
create_field: bool = False,
):
logger.info(
"toggle_item_tag",
entry_id=entry.id,
toggle_value=toggle_value,
tag_id=tag_id,
field_key=field_key,
)
tag = self.lib.get_tag(tag_id)
if toggle_value:
self.lib.add_field_tag(entry, tag, field_key, create_field)
else:
self.lib.remove_field_tag(entry, tag.id, field_key)
if self.driver.preview_panel.is_open:
self.driver.preview_panel.update_widgets()

View File

@@ -24,7 +24,7 @@ logging.basicConfig(format="%(message)s", level=logging.INFO)
class LandingWidget(QWidget):
def __init__(self, driver: "QtDriver", pixel_ratio: float):
super().__init__()
self.driver: "QtDriver" = driver
self.driver = driver
self.logo_label: ClickableLabel = ClickableLabel()
self._pixel_ratio: float = pixel_ratio
self._logo_width: int = int(480 * pixel_ratio)
@@ -56,7 +56,6 @@ class LandingWidget(QWidget):
self.logo_special_anim.setDuration(500)
# Create "Open/Create Library" button ----------------------------------
open_shortcut_text: str = ""
if sys.platform == "darwin":
open_shortcut_text = "(⌘+O)"
else:

View File

@@ -2,7 +2,6 @@
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
from types import FunctionType
from typing import Callable
from PySide6.QtCore import Signal, Qt
@@ -16,12 +15,11 @@ class PanelModal(QWidget):
# figure out what you want from this.
def __init__(
self,
widget: "PanelWidget",
widget,
title: str,
window_title: str,
done_callback: Callable = None,
# cancel_callback:FunctionType=None,
save_callback: Callable = None,
done_callback: Callable | None = None,
save_callback: Callable | None = None,
has_save: bool = False,
):
# [Done]
@@ -76,10 +74,12 @@ class PanelModal(QWidget):
if done_callback:
self.save_button.clicked.connect(done_callback)
if save_callback:
self.save_button.clicked.connect(
lambda: save_callback(widget.get_content())
)
self.button_layout.addWidget(self.save_button)
# trigger save button actions when pressing enter in the widget

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,6 @@
import math
import os
from types import FunctionType
from pathlib import Path
@@ -13,15 +12,10 @@ from PySide6.QtCore import Signal, Qt, QEvent
from PySide6.QtGui import QEnterEvent, QAction
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QPushButton
from src.core.library import Library, Tag
from src.core.library import Tag
from src.core.palette import ColorType, get_tag_color
ERROR = f"[ERROR]"
WARNING = f"[WARNING]"
INFO = f"[INFO]"
class TagWidget(QWidget):
edit_icon_128: Image.Image = Image.open(
str(Path(__file__).parents[3] / "resources/qt/images/edit_icon_128.png")
@@ -33,7 +27,6 @@ class TagWidget(QWidget):
def __init__(
self,
library: Library,
tag: Tag,
has_edit: bool,
has_remove: bool,
@@ -42,10 +35,9 @@ class TagWidget(QWidget):
on_edit_callback: FunctionType = None,
) -> None:
super().__init__()
self.lib = library
self.tag = tag
self.has_edit: bool = has_edit
self.has_remove: bool = has_remove
self.has_edit = has_edit
self.has_remove = has_remove
# self.bg_label = QLabel()
# self.setStyleSheet('background-color:blue;')
@@ -57,7 +49,7 @@ class TagWidget(QWidget):
self.bg_button = QPushButton(self)
self.bg_button.setFlat(True)
self.bg_button.setText(tag.display_name(self.lib).replace("&", "&&"))
self.bg_button.setText(tag.name)
if has_edit:
edit_action = QAction("Edit", self)
edit_action.triggered.connect(on_edit_callback)
@@ -65,13 +57,8 @@ class TagWidget(QWidget):
self.bg_button.addAction(edit_action)
# if on_click_callback:
self.bg_button.setContextMenuPolicy(Qt.ContextMenuPolicy.ActionsContextMenu)
# if has_remove:
# remove_action = QAction('Remove', self)
# # remove_action.triggered.connect(on_remove_callback)
# remove_action.triggered.connect(self.on_remove.emit())
# self.bg_button.addAction(remove_action)
search_for_tag_action = QAction("Search for Tag", self)
# search_for_tag_action.triggered.connect(on_click_callback)
search_for_tag_action.triggered.connect(self.on_click.emit)
self.bg_button.addAction(search_for_tag_action)
add_to_search_action = QAction("Add to Search", self)
@@ -106,7 +93,7 @@ class TagWidget(QWidget):
f"border-color:{get_tag_color(ColorType.BORDER, tag.color)};"
f"border-radius: 6px;"
f"border-style:solid;"
f"border-width: {math.ceil(1*self.devicePixelRatio())}px;"
f"border-width: {math.ceil(self.devicePixelRatio())}px;"
# f'border-top:2px solid {get_tag_color(ColorType.LIGHT_ACCENT, tag.color)};'
# f'border-bottom:2px solid {get_tag_color(ColorType.BORDER, tag.color)};'
# f'border-left:2px solid {get_tag_color(ColorType.BORDER, tag.color)};'
@@ -167,35 +154,6 @@ class TagWidget(QWidget):
# self.remove_button.clicked.connect(on_remove_callback)
self.remove_button.clicked.connect(self.on_remove.emit)
# NOTE: No more edit button! Just make it a right-click option.
# self.edit_button = QPushButton(self)
# self.edit_button.setFlat(True)
# self.edit_button.setText('Edit')
# self.edit_button.setIcon(QPixmap.fromImage(ImageQt.ImageQt(self.edit_icon_128)))
# self.edit_button.setIconSize(QSize(14,14))
# self.edit_button.setHidden(True)
# self.edit_button.setStyleSheet(f'color: {color};'
# f"background: {'black' if color not in ['black', 'gray', 'dark gray', 'cool gray', 'warm gray', 'blue', 'purple', 'violet'] else 'white'};"
# # f"color: {'black' if color not in ['black', 'gray', 'dark gray', 'cool gray', 'warm gray', 'blue', 'purple', 'violet'] else 'white'};"
# f"border-color: {'black' if color not in ['black', 'gray', 'dark gray', 'cool gray', 'warm gray', 'blue', 'purple', 'violet'] else 'white'};"
# f'font-weight: 600;'
# # f"border-color:{'black' if color not in [
# # 'black', 'gray', 'dark gray',
# # 'cool gray', 'warm gray', 'blue',
# # 'purple', 'violet'] else 'white'};"
# # f'QPushButton{{border-image: url(:/images/edit_icon_128.png);}}'
# # f'QPushButton{{border-image: url(:/images/edit_icon_128.png);}}'
# f'border-radius: 4px;'
# # f'border-style:solid;'
# # f'border-width:1px;'
# f'padding-top: 1.5px;'
# f'padding-right: 4px;'
# f'padding-bottom: 3px;'
# f'padding-left: 4px;'
# f'font-size: 14px')
# self.edit_button.setMinimumSize(18,18)
# # self.edit_button.setMaximumSize(18,18)
# self.inner_layout.addWidget(self.edit_button)
if has_remove:
self.inner_layout.addWidget(self.remove_button)
@@ -209,32 +167,6 @@ class TagWidget(QWidget):
# self.setMinimumSize(50,20)
# def set_name(self, name:str):
# self.bg_label.setText(str)
# def on_remove(self):
# if self.item and self.item[0] == ItemType.ENTRY:
# if self.field_index >= 0:
# self.lib.get_entry(self.item[1]).remove_tag(self.tag.id, self.field_index)
# else:
# self.lib.get_entry(self.item[1]).remove_tag(self.tag.id)
# def set_click(self, callback):
# try:
# self.bg_button.clicked.disconnect()
# except RuntimeError:
# pass
# if callback:
# self.bg_button.clicked.connect(callback)
# def set_click(self, function):
# try:
# self.bg.clicked.disconnect()
# except RuntimeError:
# pass
# # self.bg.clicked.connect(lambda checked=False, filepath=filepath: open_file(filepath))
# # self.bg.clicked.connect(function)
def enterEvent(self, event: QEnterEvent) -> None:
if self.has_remove:
self.remove_button.setHidden(False)

View File

@@ -3,15 +3,17 @@
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
import math
import typing
import structlog
from PySide6.QtCore import Signal, Qt
from PySide6.QtWidgets import QPushButton
from src.core.constants import TAG_FAVORITE, TAG_ARCHIVED
from src.core.library import Library, Tag
from src.core.library import Entry, Tag
from src.core.library.alchemy.enums import FilterState
from src.core.library.alchemy.fields import TagBoxField
from src.qt.flowlayout import FlowLayout
from src.qt.widgets.fields import FieldWidget
from src.qt.widgets.tag import TagWidget
@@ -19,30 +21,28 @@ from src.qt.widgets.panel import PanelModal
from src.qt.modals.build_tag import BuildTagPanel
from src.qt.modals.tag_search import TagSearchPanel
# Only import for type checking/autocompletion, will not be imported at runtime.
if typing.TYPE_CHECKING:
from src.qt.ts_qt import QtDriver
logger = structlog.get_logger(__name__)
class TagBoxWidget(FieldWidget):
updated = Signal()
error_occurred = Signal(Exception)
def __init__(
self,
item,
title,
field_index,
library: Library,
tags: list[int],
field: TagBoxField,
title: str,
driver: "QtDriver",
) -> None:
super().__init__(title)
# QObject.__init__(self)
self.item = item
self.lib = library
assert isinstance(field, TagBoxField), f"field is {type(field)}"
self.field = field
self.driver = driver # Used for creating tag click callbacks that search entries for that tag.
self.field_index = field_index
self.tags: list[int] = tags
self.setObjectName("tagBox")
self.base_layout = FlowLayout()
self.base_layout.setGridEfficiency(False)
@@ -62,11 +62,8 @@ class TagBoxWidget(FieldWidget):
f"border-color: #333333;"
f"border-radius: 6px;"
f"border-style:solid;"
f"border-width:{math.ceil(1*self.devicePixelRatio())}px;"
# f'padding-top: 1.5px;'
# f'padding-right: 4px;'
f"border-width:{math.ceil(self.devicePixelRatio())}px;"
f"padding-bottom: 5px;"
# f'padding-left: 4px;'
f"font-size: 20px;"
f"}}"
f"QPushButton::hover"
@@ -75,46 +72,44 @@ class TagBoxWidget(FieldWidget):
f"background: #555555;"
f"}}"
)
tsp = TagSearchPanel(self.lib)
tsp = TagSearchPanel(self.driver.lib)
tsp.tag_chosen.connect(lambda x: self.add_tag_callback(x))
self.add_modal = PanelModal(tsp, title, "Add Tags")
self.add_button.clicked.connect(
lambda: (tsp.update_tags(), self.add_modal.show()) # type: ignore
lambda: (
tsp.update_tags(),
self.add_modal.show(),
)
)
self.set_tags(tags)
# self.add_button.setHidden(True)
self.set_tags(field.tags)
def set_item(self, item):
self.item = item
def set_field(self, field: TagBoxField):
self.field = field
def set_tags(self, tags: list[int]):
logging.info(f"[TAG BOX WIDGET] SET TAGS: T:{tags} for E:{self.item.id}")
def set_tags(self, tags: typing.Iterable[Tag]):
is_recycled = False
if self.base_layout.itemAt(0):
# logging.info(type(self.base_layout.itemAt(0).widget()))
while self.base_layout.itemAt(0) and self.base_layout.itemAt(1):
# logging.info(f"I'm deleting { self.base_layout.itemAt(0).widget()}")
self.base_layout.takeAt(0).widget().deleteLater()
while self.base_layout.itemAt(0) and self.base_layout.itemAt(1):
self.base_layout.takeAt(0).widget().deleteLater()
is_recycled = True
for tag in tags:
# TODO: Remove space from the special search here (tag_id:x) once that system is finalized.
# tw = TagWidget(self.lib, self.lib.get_tag(tag), True, True,
# on_remove_callback=lambda checked=False, t=tag: (self.lib.get_entry(self.item.id).remove_tag(self.lib, t, self.field_index), self.updated.emit()),
# on_click_callback=lambda checked=False, q=f'tag_id: {tag}': (self.driver.main_window.searchField.setText(q), self.driver.filter_items(q)),
# on_edit_callback=lambda checked=False, t=tag: (self.edit_tag(t))
# )
tw = TagWidget(self.lib, self.lib.get_tag(tag), True, True)
tw.on_click.connect(
lambda checked=False, q=f"tag_id: {tag}": (
self.driver.main_window.searchField.setText(q),
self.driver.filter_items(q),
tag_widget = TagWidget(tag, True, True)
tag_widget.on_click.connect(
lambda tag_id=tag.id: (
self.driver.main_window.searchField.setText(f"tag_id:{tag_id}"),
self.driver.filter_items(FilterState(tag_id=tag_id)),
)
)
tw.on_remove.connect(lambda checked=False, t=tag: (self.remove_tag(t)))
tw.on_edit.connect(lambda checked=False, t=tag: (self.edit_tag(t)))
self.base_layout.addWidget(tw)
self.tags = tags
tag_widget.on_remove.connect(
lambda tag_id=tag.id: (
self.remove_tag(tag_id),
self.driver.preview_panel.update_widgets(),
)
)
tag_widget.on_edit.connect(lambda t=tag: self.edit_tag(t))
self.base_layout.addWidget(tag_widget)
# Move or add the '+' button.
if is_recycled:
@@ -127,62 +122,59 @@ class TagBoxWidget(FieldWidget):
if self.base_layout.itemAt(0) and not self.base_layout.itemAt(1):
self.base_layout.update()
def edit_tag(self, tag_id: int):
btp = BuildTagPanel(self.lib, tag_id)
# btp.on_edit.connect(lambda x: self.edit_tag_callback(x))
def edit_tag(self, tag: Tag):
assert isinstance(tag, Tag), f"tag is {type(tag)}"
build_tag_panel = BuildTagPanel(self.driver.lib, tag=tag)
self.edit_modal = PanelModal(
btp,
self.lib.get_tag(tag_id).display_name(self.lib),
build_tag_panel,
tag.name, # TODO - display name including subtags
"Edit Tag",
done_callback=(self.driver.preview_panel.update_widgets),
done_callback=self.driver.preview_panel.update_widgets,
has_save=True,
)
# self.edit_modal.widget.update_display_name.connect(lambda t: self.edit_modal.title_widget.setText(t))
self.edit_modal.saved.connect(lambda: self.lib.update_tag(btp.build_tag()))
# TODO - this was update_tag()
self.edit_modal.saved.connect(
lambda: self.driver.lib.update_tag(
build_tag_panel.build_tag(),
subtag_ids=build_tag_panel.subtags,
)
)
# panel.tag_updated.connect(lambda tag: self.lib.update_tag(tag))
self.edit_modal.show()
def add_tag_callback(self, tag_id: int):
# self.base_layout.addWidget(TagWidget(self.lib, self.lib.get_tag(tag), True))
# self.tags.append(tag)
logging.info(
f"[TAG BOX WIDGET] ADD TAG CALLBACK: T:{tag_id} to E:{self.item.id}"
)
logging.info(f"[TAG BOX WIDGET] SELECTED T:{self.driver.selected}")
id: int = list(self.field.keys())[0] # type: ignore
for x in self.driver.selected:
self.driver.lib.get_entry(x[1]).add_tag(
self.driver.lib, tag_id, field_id=id, field_index=-1
)
self.updated.emit()
logger.info("add_tag_callback", tag_id=tag_id, selected=self.driver.selected)
tag = self.driver.lib.get_tag(tag_id=tag_id)
for idx in self.driver.selected:
entry: Entry = self.driver.frame_content[idx]
if not self.driver.lib.add_field_tag(entry, tag, self.field.type_key):
# TODO - add some visible error
self.error_occurred.emit(Exception("Failed to add tag"))
self.updated.emit()
if tag_id in (TAG_FAVORITE, TAG_ARCHIVED):
self.driver.update_badges()
# if type((x[0]) == ThumbButton):
# # TODO: Remove space from the special search here (tag_id:x) once that system is finalized.
# logging.info(f'I want to add tag ID {tag_id} to entry {self.item.filename}')
# self.updated.emit()
# if tag_id not in self.tags:
# self.tags.append(tag_id)
# self.set_tags(self.tags)
# elif type((x[0]) == ThumbButton):
def edit_tag_callback(self, tag: Tag):
self.lib.update_tag(tag)
self.driver.lib.update_tag(tag)
def remove_tag(self, tag_id: int):
logging.info(f"[TAG BOX WIDGET] SELECTED T:{self.driver.selected}")
id: int = list(self.field.keys())[0] # type: ignore
for x in self.driver.selected:
index = self.driver.lib.get_field_index_in_entry(
self.driver.lib.get_entry(x[1]), id
)
self.driver.lib.get_entry(x[1]).remove_tag(
self.driver.lib, tag_id, field_index=index[0]
)
logger.info(
"remove_tag",
selected=self.driver.selected,
field_type=self.field.type,
)
for grid_idx in self.driver.selected:
entry = self.driver.frame_content[grid_idx]
self.driver.lib.remove_field_tag(entry, tag_id, self.field.type_key)
self.updated.emit()
if tag_id in (TAG_FAVORITE, TAG_ARCHIVED):
self.driver.update_badges()
# def show_add_button(self, value:bool):
# self.add_button.setHidden(not value)

View File

@@ -3,7 +3,6 @@
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
import logging
import math
from pathlib import Path
@@ -22,6 +21,7 @@ from PIL import (
from PIL.Image import DecompressionBombError
from PySide6.QtCore import QObject, Signal, QSize
from PySide6.QtGui import QPixmap
from src.qt.helpers.gradient import four_corner_gradient_background
from src.core.constants import (
PLAINTEXT_TYPES,
@@ -29,15 +29,15 @@ from src.core.constants import (
IMAGE_TYPES,
RAW_IMAGE_TYPES,
)
import structlog
from src.core.utils.encoding import detect_char_encoding
ImageFile.LOAD_TRUNCATED_IMAGES = True
ERROR = "[ERROR]"
WARNING = "[WARNING]"
INFO = "[INFO]"
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = structlog.get_logger(__name__)
register_heif_opener()
register_avif_opener()
@@ -95,7 +95,10 @@ class ThumbRenderer(QObject):
gradient=False,
update_on_ratio_change=False,
):
"""Internal renderer. Renders an entry/element thumbnail for the GUI."""
"""Internal renderer. Render an entry/element thumbnail for the GUI."""
logger.debug("rendering thumbnail", path=filepath)
image: Image.Image = None
pixmap: QPixmap = None
final: Image.Image = None
@@ -133,8 +136,8 @@ class ThumbRenderer(QObject):
image = ImageOps.exif_transpose(image)
except DecompressionBombError as e:
logging.info(
f"[ThumbRenderer]{WARNING} Couldn't Render thumbnail for {_filepath.name} ({type(e).__name__})"
logger.error(
"Couldn't Render thumbnail", filepath=filepath, error=e
)
elif _filepath.suffix.lower() in RAW_IMAGE_TYPES:
@@ -148,15 +151,16 @@ class ThumbRenderer(QObject):
decoder_name="raw",
)
except DecompressionBombError as e:
logging.info(
f"[ThumbRenderer]{WARNING} Couldn't Render thumbnail for {_filepath.name} ({type(e).__name__})"
logger.error(
"Couldn't Render thumbnail", filepath=filepath, error=e
)
except (
rawpy._rawpy.LibRawIOError,
rawpy._rawpy.LibRawFileUnsupportedError,
) as e:
logging.info(
f"[ThumbRenderer]{ERROR} Couldn't Render thumbnail for raw image {_filepath.name} ({type(e).__name__})"
logger.error(
"Couldn't Render thumbnail", filepath=filepath, error=e
)
# Videos =======================================================
@@ -179,7 +183,7 @@ class ThumbRenderer(QObject):
# Plain Text ===================================================
elif _filepath.suffix.lower() in PLAINTEXT_TYPES:
encoding = detect_char_encoding(_filepath)
with open(_filepath, "r", encoding=encoding) as text_file:
with open(_filepath, encoding=encoding) as text_file:
text = text_file.read(256)
bg = Image.new("RGB", (256, 256), color="#1e1e1e")
draw = ImageDraw.Draw(bg)
@@ -268,9 +272,10 @@ class ThumbRenderer(QObject):
UnicodeDecodeError,
) as e:
if e is not UnicodeDecodeError:
logging.info(
f"[ThumbRenderer]{ERROR}: Couldn't render thumbnail for {_filepath.name} ({type(e).__name__})"
logger.error(
"Couldn't Render thumbnail", filepath=filepath, error=e
)
if update_on_ratio_change:
self.updated_ratio.emit(1)
final = ThumbRenderer.thumb_broken_512.resize(

View File

@@ -3,7 +3,6 @@
import logging
from pathlib import Path
import typing
from PySide6.QtCore import (
@@ -122,7 +121,7 @@ class VideoPlayer(QGraphicsView):
autoplay_action.setCheckable(True)
self.addAction(autoplay_action)
autoplay_action.setChecked(
self.driver.settings.value(SettingItems.AUTOPLAY, True, bool) # type: ignore
bool(self.driver.settings.value(SettingItems.AUTOPLAY, True, type=bool))
)
autoplay_action.triggered.connect(lambda: self.toggleAutoplay())
self.autoplay = autoplay_action
@@ -176,37 +175,30 @@ class VideoPlayer(QGraphicsView):
def eventFilter(self, obj: QObject, event: QEvent) -> bool:
# This chunk of code is for the video controls.
if (
obj == self.play_pause
and event.type() == QEvent.Type.MouseButtonPress
event.type() == QEvent.Type.MouseButtonPress
and event.button() == Qt.MouseButton.LeftButton # type: ignore
):
if self.player.hasVideo():
if obj == self.play_pause and self.player.hasVideo():
self.pauseToggle()
if (
obj == self.mute_button
and event.type() == QEvent.Type.MouseButtonPress
and event.button() == Qt.MouseButton.LeftButton # type: ignore
):
if self.player.hasAudio():
elif obj == self.mute_button and self.player.hasAudio():
self.muteToggle()
if (
obj == self.video_preview
and event.type() == QEvent.Type.GraphicsSceneHoverEnter
or event.type() == QEvent.Type.HoverEnter
):
if self.video_preview.isUnderMouse():
self.underMouse()
self.hover_fix_timer.start(10)
elif (
obj == self.video_preview
and event.type() == QEvent.Type.GraphicsSceneHoverLeave
or event.type() == QEvent.Type.HoverLeave
):
if not self.video_preview.isUnderMouse():
elif obj == self.video_preview:
if event.type() in (
QEvent.Type.GraphicsSceneHoverEnter,
QEvent.Type.HoverEnter,
):
if self.video_preview.isUnderMouse():
self.underMouse()
self.hover_fix_timer.start(10)
elif (
event.type()
in (QEvent.Type.GraphicsSceneHoverLeave, QEvent.Type.HoverLeave)
and not self.video_preview.isUnderMouse()
):
self.hover_fix_timer.stop()
self.releaseMouse()
return super().eventFilter(obj, event)
def checkIfStillHovered(self) -> None:
@@ -334,14 +326,13 @@ class VideoPlayer(QGraphicsView):
int(self.video_preview.size().height()),
)
)
return
class VideoPreview(QGraphicsVideoItem):
def boundingRect(self):
return QRectF(0, 0, self.size().width(), self.size().height())
def paint(self, painter, option, widget):
def paint(self, painter, option, widget=None) -> None:
# painter.brush().setColor(QColor(0, 0, 0, 255))
# You can set any shape you want here.
# RoundedRect is the standard rectangle with rounded corners.

33
tagstudio/tag_studio.py Normal file → Executable file
View File

@@ -1,16 +1,23 @@
#!/usr/bin/env python
# Copyright (C) 2024 Travis Abendshien (CyanVoxel).
# Licensed under the GPL-3.0 License.
# Created for TagStudio: https://github.com/CyanVoxel/TagStudio
"""TagStudio launcher."""
from src.core.ts_core import TagStudioCore
from src.cli.ts_cli import CliDriver # type: ignore
import structlog
import logging
from src.qt.ts_qt import QtDriver
import argparse
import traceback
structlog.configure(
wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
)
def main():
# appid = "cyanvoxel.tagstudio.9"
# ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(appid)
@@ -48,27 +55,11 @@ def main():
type=str,
help="User interface option for TagStudio. Options: qt, cli (Default: qt)",
)
parser.add_argument(
"--ci",
action=argparse.BooleanOptionalAction,
help="Exit the application after checking it starts without any problem. Meant for CI check.",
)
args = parser.parse_args()
from src.core.library import alchemy as backend
core = TagStudioCore() # The TagStudio Core instance. UI agnostic.
driver = None # The UI driver instance.
ui_name: str = "unknown" # Display name for the UI, used in logs.
# Driver selection based on parameters.
if args.ui and args.ui == "qt":
driver = QtDriver(core, args)
ui_name = "Qt"
elif args.ui and args.ui == "cli":
driver = CliDriver(core, args)
ui_name = "CLI"
else:
driver = QtDriver(core, args)
ui_name = "Qt"
driver = QtDriver(backend, args)
ui_name = "Qt"
# Run the chosen frontend driver.
try:

View File

@@ -1,42 +1,131 @@
import sys
import pathlib
from tempfile import TemporaryDirectory
from unittest.mock import patch, Mock
import pytest
from syrupy.extensions.json import JSONSnapshotExtension
CWD = pathlib.Path(__file__).parent
# this needs to be above `src` imports
sys.path.insert(0, str(CWD.parent))
from src.core.library import Tag, Library
from src.core.library import Library, Tag, Entry
from src.core.library.alchemy.enums import TagColor
from src.core.library.alchemy.fields import TagBoxField, _FieldID
from src.core.library import alchemy as backend
from src.qt.ts_qt import QtDriver
@pytest.fixture
def test_tag():
yield Tag(
id=1,
name="Tag Name",
shorthand="TN",
aliases=["First A", "Second A"],
subtags_ids=[2, 3, 4],
color="",
)
def cwd():
return CWD
@pytest.fixture
def test_library():
lib_dir = CWD / "fixtures" / "library"
def library(request):
# when no param is passed, use the default
library_path = "/tmp/"
if hasattr(request, "param"):
if isinstance(request.param, TemporaryDirectory):
library_path = request.param.name
else:
library_path = request.param
lib = Library()
ret_code = lib.open_library(lib_dir)
assert ret_code == 1
# create files for the entries
for entry in lib.entries:
(lib_dir / entry.filename).touch()
lib.open_library(library_path, ":memory:")
assert lib.folder
tag = Tag(
name="foo",
color=TagColor.RED,
)
assert lib.add_tag(tag)
subtag = Tag(
name="subbar",
color=TagColor.YELLOW,
)
tag2 = Tag(
name="bar",
color=TagColor.BLUE,
subtags={subtag},
)
# default item with deterministic name
entry = Entry(
folder=lib.folder,
path=pathlib.Path("foo.txt"),
fields=lib.default_fields,
)
entry.tag_box_fields = [
TagBoxField(type_key=_FieldID.TAGS.name, tags={tag}, position=0),
TagBoxField(
type_key=_FieldID.TAGS_META.name,
position=0,
),
]
entry2 = Entry(
folder=lib.folder,
path=pathlib.Path("one/two/bar.md"),
fields=lib.default_fields,
)
entry2.tag_box_fields = [
TagBoxField(
tags={tag2},
type_key=_FieldID.TAGS_META.name,
position=0,
),
]
assert lib.add_entries([entry, entry2])
assert len(lib.tags) == 5
yield lib
@pytest.fixture
def snapshot_json(snapshot):
return snapshot.with_defaults(extension_class=JSONSnapshotExtension)
def entry_min(library):
yield next(library.get_entries())
@pytest.fixture
def entry_full(library):
yield next(library.get_entries(with_joins=True))
@pytest.fixture
def qt_driver(qtbot, library):
with TemporaryDirectory() as tmp_dir:
class Args:
config_file = pathlib.Path(tmp_dir) / "tagstudio.ini"
open = pathlib.Path(tmp_dir)
ci = True
# patch CustomRunnable
with patch("src.qt.ts_qt.Consumer"), patch("src.qt.ts_qt.CustomRunnable"):
driver = QtDriver(backend, Args())
driver.main_window = Mock()
driver.preview_panel = Mock()
driver.flow_container = Mock()
driver.item_thumbs = []
driver.lib = library
# TODO - downsize this method and use it
# driver.start()
driver.frame_content = list(library.get_entries())
yield driver
@pytest.fixture
def generate_tag():
def inner(name, **kwargs):
params = dict(name=name, color=TagColor.RED) | kwargs
return Tag(**params)
yield inner

View File

@@ -1,6 +0,0 @@
[
[
"<ItemType.ENTRY: 0>",
2
]
]

View File

@@ -1,6 +0,0 @@
[
[
"<ItemType.ENTRY: 0>",
1
]
]

View File

@@ -1,4 +0,0 @@
[
"{'id': 1, 'filename': 'foo.txt', 'path': '.', 'fields': [{6: [1001]}]}",
"{'id': 2, 'filename': 'bar.txt', 'path': '.', 'fields': [{6: [1000]}]}"
]

View File

@@ -1,18 +0,0 @@
import pytest
def test_open_library(test_library, snapshot_json):
assert test_library.entries == snapshot_json
@pytest.mark.parametrize(
["query"],
[
("First",),
("Second",),
("--nomatch--",),
],
)
def test_library_search(test_library, query, snapshot_json):
res = test_library.search_library(query)
assert res == snapshot_json

View File

@@ -1,8 +0,0 @@
def test_subtag(test_tag):
test_tag.remove_subtag(2)
test_tag.remove_subtag(2)
test_tag.add_subtag(5)
# repeated add should not add the subtag
test_tag.add_subtag(5)
assert test_tag.subtag_ids == [3, 4, 5]

View File

@@ -1,69 +0,0 @@
{
"ts-version": "9.3.1",
"ext_list": [
".json",
".xmp",
".aae"
],
"is_exclude_list": true,
"tags": [
{
"id": 0,
"name": "Archived",
"aliases": [
"Archive"
],
"color": "Red"
},
{
"id": 1,
"name": "Favorite",
"aliases": [
"Favorited",
"Favorites"
],
"color": "Yellow"
},
{
"id": 1000,
"name": "first",
"shorthand": "first",
"color": "magenta"
},
{
"id": 1001,
"name": "second",
"shorthand": "second",
"color": "blue"
}
],
"collations": [],
"fields": [],
"macros": [],
"entries": [
{
"id": 1,
"filename": "foo.txt",
"path": ".",
"fields": [
{
"6": [
1001
]
}
]
},
{
"id": 2,
"filename": "bar.txt",
"path": ".",
"fields": [
{
"6": [
1000
]
}
]
}
]
}

View File

@@ -0,0 +1,10 @@
<results>
<group>
<file path="/tmp/bar/foo.txt" words="" is_ref="n" marked="n"/>
<file path="/tmp/foo.txt" words="" is_ref="n" marked="n"/>
<file path="/tmp/foo/foo.txt" words="" is_ref="n" marked="n"/>
<match first="1" second="0" percentage="100"/>
<match first="0" second="2" percentage="100"/>
<match first="1" second="2" percentage="100"/>
</group>
</results>

View File

@@ -0,0 +1,10 @@
{
"tags": [
"ng_tag",
"ng_tag2"
],
"date": "2024-01-02",
"description": "NG description",
"user": "NG artist",
"post_url": "https://ng.com"
}

View File

@@ -0,0 +1,35 @@
import pathlib
from src.core.library import Entry
from src.core.utils.dupe_files import DupeRegistry
CWD = pathlib.Path(__file__).parent
def test_refresh_dupe_files(library):
entry = Entry(
folder=library.folder,
path=pathlib.Path("bar/foo.txt"),
fields=library.default_fields,
)
entry2 = Entry(
folder=library.folder,
path=pathlib.Path("foo/foo.txt"),
fields=library.default_fields,
)
library.add_entries([entry, entry2])
registry = DupeRegistry(library=library)
dupe_file_path = CWD.parent / "fixtures" / "result.dupeguru"
registry.refresh_dupe_files(dupe_file_path)
assert len(registry.groups) == 1
paths = [entry.path for entry in registry.groups[0]]
assert paths == [
pathlib.Path("bar/foo.txt"),
pathlib.Path("foo.txt"),
pathlib.Path("foo/foo.txt"),
]

View File

@@ -0,0 +1,9 @@
from src.qt.modals.folders_to_tags import folders_to_tags
def test_folders_to_tags(library):
folders_to_tags(library)
entry = [
x for x in library.get_entries(with_joins=True) if "bar.md" in str(x.path)
][0]
assert {x.name for x in entry.tags} == {"two", "bar"}

View File

@@ -0,0 +1,31 @@
import pathlib
from tempfile import TemporaryDirectory
import pytest
from src.core.library import Library
from src.core.library.alchemy.enums import FilterState
from src.core.utils.missing_files import MissingRegistry
CWD = pathlib.Path(__file__).parent
@pytest.mark.parametrize("library", [TemporaryDirectory()], indirect=True)
def test_refresh_missing_files(library: Library):
registry = MissingRegistry(library=library)
# touch the file `one/two/bar.md` but in wrong location to simulate a moved file
(library.library_dir / "bar.md").touch()
# no files actually exist, so it should return all entries
assert list(registry.refresh_missing_files()) == [0, 1]
# neither of the library entries exist
assert len(registry.missing_files) == 2
# iterate through two files
assert list(registry.fix_missing_files()) == [1, 2]
# `bar.md` should be relinked to new correct path
_, entries = library.search_library(FilterState(path="bar.md"))
assert entries[0].path == pathlib.Path("bar.md")

View File

@@ -0,0 +1,19 @@
import pathlib
from tempfile import TemporaryDirectory
import pytest
from src.core.utils.refresh_dir import RefreshDirTracker
CWD = pathlib.Path(__file__).parent
@pytest.mark.parametrize("library", [TemporaryDirectory()], indirect=True)
def test_refresh_new_files(library):
registry = RefreshDirTracker(library=library)
# touch new files to simulate new files
(library.library_dir / "foo.md").touch()
assert not list(registry.refresh_dir())
assert registry.files_not_in_library == [pathlib.Path("foo.md")]

View File

@@ -0,0 +1,37 @@
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
from src.core.enums import MacroID
from src.core.library.alchemy.fields import _FieldID
@pytest.mark.parametrize("library", [TemporaryDirectory()], indirect=True)
def test_sidecar_macro(qt_driver, library, cwd, entry_full):
entry_full.path = Path("newgrounds/foo.txt")
fixture = cwd / "fixtures/sidecar_newgrounds.json"
dst = library.library_dir / "newgrounds" / (entry_full.path.stem + ".json")
dst.parent.mkdir()
shutil.copy(fixture, dst)
qt_driver.frame_content = [entry_full]
qt_driver.run_macro(MacroID.SIDECAR, 0)
entry = next(library.get_entries(with_joins=True))
new_fields = (
(_FieldID.DESCRIPTION.name, "NG description"),
(_FieldID.ARTIST.name, "NG artist"),
(_FieldID.SOURCE.name, "https://ng.com"),
(_FieldID.TAGS.name, None),
)
found = [(field.type.key, field.value) for field in entry.fields]
# `new_fields` should be subset of `found`
for field in new_fields:
assert field in found, f"Field not found: {field} / {found}"
expected_tags = {"ng_tag", "ng_tag2"}
assert {x.name in expected_tags for x in entry.tags}

View File

@@ -0,0 +1,4 @@
# serializer version: 1
# name: test_generate_preview_data
BranchData(dirs={'two': BranchData(dirs={}, files=['bar.md'], tag=<Tag ID: None Name: two>)}, files=[], tag=None)
# ---

View File

@@ -0,0 +1,111 @@
from pathlib import Path
from unittest.mock import Mock
from src.core.library import Entry
from src.core.library.alchemy.enums import FilterState
from src.core.library.json.library import ItemType
from src.qt.widgets.item_thumb import ItemThumb
def test_update_thumbs(qt_driver):
qt_driver.frame_content = [
Entry(
folder=qt_driver.lib.folder,
path=Path("/tmp/foo"),
fields=qt_driver.lib.default_fields,
)
]
qt_driver.item_thumbs = []
for i in range(3):
qt_driver.item_thumbs.append(
ItemThumb(
mode=ItemType.ENTRY,
library=qt_driver.lib,
driver=qt_driver,
thumb_size=(100, 100),
grid_idx=i,
)
)
qt_driver.update_thumbs()
for idx, thumb in enumerate(qt_driver.item_thumbs):
# only first item is visible
assert thumb.isVisible() == (idx == 0)
def test_select_item_bridge(qt_driver, entry_min):
# mock some props since we're not running `start()`
qt_driver.autofill_action = Mock()
qt_driver.sort_fields_action = Mock()
# set the content manually
qt_driver.frame_content = [entry_min] * 3
qt_driver.filter.page_size = 3
qt_driver._init_thumb_grid()
assert len(qt_driver.item_thumbs) == 3
# select first item
qt_driver.select_item(0, False, False)
assert qt_driver.selected == [0]
# add second item to selection
qt_driver.select_item(1, False, bridge=True)
assert qt_driver.selected == [0, 1]
# add third item to selection
qt_driver.select_item(2, False, bridge=True)
assert qt_driver.selected == [0, 1, 2]
# select third item only
qt_driver.select_item(2, False, bridge=False)
assert qt_driver.selected == [2]
qt_driver.select_item(0, False, bridge=True)
assert qt_driver.selected == [0, 1, 2]
def test_library_state_update(qt_driver):
# Given
for idx, entry in enumerate(qt_driver.lib.get_entries(with_joins=True)):
thumb = ItemThumb(ItemType.ENTRY, qt_driver.lib, qt_driver, (100, 100), idx)
qt_driver.item_thumbs.append(thumb)
qt_driver.frame_content.append(entry)
# no filter, both items are returned
qt_driver.filter_items()
assert len(qt_driver.frame_content) == 2
# filter by tag
state = FilterState(tag="foo", page_size=10)
qt_driver.filter_items(state)
assert qt_driver.filter.page_size == 10
assert len(qt_driver.frame_content) == 1
entry = qt_driver.frame_content[0]
assert list(entry.tags)[0].name == "foo"
# When state is not changed, previous one is still applied
qt_driver.filter_items()
assert qt_driver.filter.page_size == 10
assert len(qt_driver.frame_content) == 1
entry = qt_driver.frame_content[0]
assert list(entry.tags)[0].name == "foo"
# When state property is changed, previous one is overriden
state = FilterState(path="bar.md")
qt_driver.filter_items(state)
assert len(qt_driver.frame_content) == 1
entry = qt_driver.frame_content[0]
assert list(entry.tags)[0].name == "bar"
def test_close_library(qt_driver):
# Given
qt_driver.close_library()
# Then
assert len(qt_driver.frame_content) == 0
assert len(qt_driver.item_thumbs) == 0
assert qt_driver.selected == []

View File

@@ -0,0 +1,18 @@
from PySide6.QtCore import QRect
from PySide6.QtWidgets import QWidget, QPushButton
from src.qt.flowlayout import FlowLayout
def test_flow_layout_happy_path(qtbot):
class Window(QWidget):
def __init__(self):
super().__init__()
self.flow_layout = FlowLayout(self)
self.flow_layout.setGridEfficiency(True)
self.flow_layout.addWidget(QPushButton("Short"))
window = Window()
assert window.flow_layout.count()
assert window.flow_layout._do_layout(QRect(0, 0, 0, 0), False)

View File

@@ -0,0 +1,7 @@
from src.qt.modals.folders_to_tags import generate_preview_data
def test_generate_preview_data(library, snapshot):
preview = generate_preview_data(library)
assert preview == snapshot

View File

@@ -0,0 +1,19 @@
import pytest
from src.core.library import ItemType
from src.qt.widgets.item_thumb import ItemThumb, BadgeType
@pytest.mark.parametrize("new_value", (True, False))
def test_badge_visual_state(library, qt_driver, entry_min, new_value):
thumb = ItemThumb(ItemType.ENTRY, qt_driver.lib, qt_driver, (100, 100), 0)
qt_driver.frame_content = [entry_min]
qt_driver.selected = [0]
qt_driver.item_thumbs = [thumb]
thumb.badges[BadgeType.FAVORITE].setChecked(new_value)
assert thumb.badges[BadgeType.FAVORITE].isChecked() == new_value
# TODO
# assert thumb.favorite_badge.isHidden() == initial_state
assert thumb.is_favorite == new_value

View File

@@ -0,0 +1,122 @@
from pathlib import Path
from src.core.library import Entry
from src.core.library.alchemy.enums import FieldTypeEnum
from src.core.library.alchemy.fields import _FieldID, TextField
from src.qt.widgets.preview_panel import PreviewPanel
def test_update_widgets_not_selected(qt_driver, library):
qt_driver.frame_content = list(library.get_entries())
qt_driver.selected = []
panel = PreviewPanel(library, qt_driver)
panel.update_widgets()
assert panel.preview_img.isVisible()
assert panel.file_label.text() == "No Items Selected"
def test_update_widgets_single_selected(qt_driver, library):
qt_driver.frame_content = list(library.get_entries())
qt_driver.selected = [0]
panel = PreviewPanel(library, qt_driver)
panel.update_widgets()
assert panel.preview_img.isVisible()
def test_update_widgets_multiple_selected(qt_driver, library):
# entry with no tag fields
entry = Entry(
path=Path("test.txt"),
folder=library.folder,
fields=[TextField(type_key=_FieldID.TITLE.name, position=0)],
)
assert not entry.tag_box_fields
library.add_entries([entry])
assert library.entries_count == 3
qt_driver.frame_content = list(library.get_entries())
qt_driver.selected = [0, 1, 2]
panel = PreviewPanel(library, qt_driver)
panel.update_widgets()
assert {f.type_key for f in panel.common_fields} == {
_FieldID.TITLE.name,
}
assert {f.type_key for f in panel.mixed_fields} == {
_FieldID.TAGS.name,
_FieldID.TAGS_META.name,
}
def test_write_container_text_line(qt_driver, entry_full, library):
# Given
panel = PreviewPanel(library, qt_driver)
field = entry_full.text_fields[0]
assert len(entry_full.text_fields) == 1
assert field.type.type == FieldTypeEnum.TEXT_LINE
assert field.type.name == "Title"
# set any value
field.value = "foo"
panel.write_container(0, field)
panel.selected = [0]
assert len(panel.containers) == 1
container = panel.containers[0]
widget = container.get_inner_widget()
# test it's not "mixed data"
assert widget.text_label.text() == "foo"
# When update and submit modal
modal = panel.containers[0].modal
modal.widget.text_edit.setText("bar")
modal.save_button.click()
# Then reload entry
entry_full = next(library.get_entries(with_joins=True))
# the value was updated
assert entry_full.text_fields[0].value == "bar"
def test_remove_field(qt_driver, library):
# Given
panel = PreviewPanel(library, qt_driver)
entries = list(library.get_entries(with_joins=True))
qt_driver.frame_content = entries
# When second entry is selected
panel.selected = [1]
field = entries[1].text_fields[0]
panel.write_container(0, field)
panel.remove_field(field)
entries = list(library.get_entries(with_joins=True))
assert not entries[1].text_fields
def test_update_field(qt_driver, library, entry_full):
panel = PreviewPanel(library, qt_driver)
# select both entries
qt_driver.frame_content = list(library.get_entries())[:2]
qt_driver.selected = [0, 1]
panel.selected = [0, 1]
# update field
title_field = entry_full.text_fields[0]
panel.update_field(title_field, "meow")
for entry in library.get_entries(with_joins=True):
field = [x for x in entry.text_fields if x.type_key == title_field.type_key][0]
assert field.value == "meow"

View File

@@ -0,0 +1,24 @@
from src.core.library import Tag
from src.qt.modals.build_tag import BuildTagPanel
def test_tag_panel(qtbot, library):
panel = BuildTagPanel(library)
qtbot.addWidget(panel)
def test_add_tag_callback(qt_driver):
# Given
assert len(qt_driver.lib.tags) == 5
qt_driver.add_tag_action_callback()
# When
qt_driver.modal.widget.name_field.setText("xxx")
qt_driver.modal.widget.color_field.setCurrentIndex(1)
qt_driver.modal.saved.emit()
# Then
tags: set[Tag] = qt_driver.lib.tags
assert len(tags) == 6
assert "xxx" in {tag.name for tag in tags}

View File

@@ -0,0 +1,11 @@
from src.qt.modals.tag_search import TagSearchPanel
def test_update_tags(qtbot, library):
# Given
panel = TagSearchPanel(library)
qtbot.addWidget(panel)
# When
panel.update_tags()

View File

@@ -0,0 +1,117 @@
from unittest.mock import patch
from src.core.library.alchemy.fields import _FieldID
from src.qt.widgets.tag import TagWidget
from src.qt.widgets.tag_box import TagBoxWidget
from src.qt.modals.build_tag import BuildTagPanel
def test_tag_widget(qtbot, library, qt_driver):
# given
entry = next(library.get_entries(with_joins=True))
field = entry.tag_box_fields[0]
tag_widget = TagBoxWidget(field, "title", qt_driver)
qtbot.add_widget(tag_widget)
assert not tag_widget.add_modal.isVisible()
# when/then check no exception is raised
tag_widget.add_button.clicked.emit()
# check `tag_widget.add_modal` is visible
assert tag_widget.add_modal.isVisible()
def test_tag_widget_add_existing_raises(library, qt_driver, entry_full):
# Given
tag_field = [
f for f in entry_full.tag_box_fields if f.type_key == _FieldID.TAGS.name
][0]
assert len(entry_full.tags) == 1
tag = next(iter(entry_full.tags))
# When
tag_widget = TagBoxWidget(tag_field, "title", qt_driver)
tag_widget.driver.frame_content = [entry_full]
tag_widget.driver.selected = [0]
# Then
with patch.object(tag_widget, "error_occurred") as mocked:
tag_widget.add_modal.widget.tag_chosen.emit(tag.id)
assert mocked.emit.called
def test_tag_widget_add_new_pass(qtbot, library, qt_driver, generate_tag):
# Given
entry = next(library.get_entries(with_joins=True))
field = entry.tag_box_fields[0]
tag = generate_tag(name="new_tag")
library.add_tag(tag)
tag_widget = TagBoxWidget(field, "title", qt_driver)
qtbot.add_widget(tag_widget)
tag_widget.driver.selected = [0]
with patch.object(tag_widget, "error_occurred") as mocked:
# When
tag_widget.add_modal.widget.tag_chosen.emit(tag.id)
# Then
assert not mocked.emit.called
def test_tag_widget_remove(qtbot, qt_driver, library, entry_full):
tag = list(entry_full.tags)[0]
assert tag
assert entry_full.tag_box_fields
tag_field = [
f for f in entry_full.tag_box_fields if f.type_key == _FieldID.TAGS.name
][0]
tag_widget = TagBoxWidget(tag_field, "title", qt_driver)
tag_widget.driver.selected = [0]
qtbot.add_widget(tag_widget)
tag_widget = tag_widget.base_layout.itemAt(0).widget()
assert isinstance(tag_widget, TagWidget)
tag_widget.remove_button.clicked.emit()
entry = next(qt_driver.lib.get_entries(with_joins=True))
assert not entry.tag_box_fields[0].tags
def test_tag_widget_edit(qtbot, qt_driver, library, entry_full):
# Given
tag = list(entry_full.tags)[0]
assert tag
assert entry_full.tag_box_fields
tag_field = [
f for f in entry_full.tag_box_fields if f.type_key == _FieldID.TAGS.name
][0]
tag_box_widget = TagBoxWidget(tag_field, "title", qt_driver)
tag_box_widget.driver.selected = [0]
qtbot.add_widget(tag_box_widget)
tag_widget = tag_box_widget.base_layout.itemAt(0).widget()
assert isinstance(tag_widget, TagWidget)
# When
actions = tag_widget.bg_button.actions()
edit_action = [a for a in actions if a.text() == "Edit"][0]
edit_action.triggered.emit()
# Then
panel = tag_box_widget.edit_modal.widget
assert isinstance(panel, BuildTagPanel)
assert panel.tag.name == tag.name
assert panel.name_field.text() == tag.name

View File

@@ -0,0 +1,37 @@
import pytest
from src.core.library.alchemy.enums import FilterState
def test_filter_state_query():
# Given
query = "tag:foo"
state = FilterState(query=query)
# When
assert state.tag == "foo"
@pytest.mark.parametrize(
["attribute", "comparator"],
[
("tag", str),
("tag_id", int),
("path", str),
("name", str),
("id", int),
],
)
def test_filter_state_attrs_compare(attribute, comparator):
# When
state = FilterState(**{attribute: "2"})
# Then
# compare the attribute value
assert getattr(state, attribute) == comparator("2")
# Then
for prop in ("tag", "tag_id", "path", "name", "id"):
if prop == attribute:
continue
assert not getattr(state, prop)

View File

@@ -0,0 +1,407 @@
from pathlib import Path, PureWindowsPath
from tempfile import TemporaryDirectory
import pytest
from src.core.constants import LibraryPrefs
from src.core.library.alchemy import Entry
from src.core.library.alchemy import Library
from src.core.library.alchemy.enums import FilterState
from src.core.library.alchemy.fields import _FieldID, TextField
def test_library_bootstrap():
with TemporaryDirectory() as tmp_dir:
lib = Library()
lib.open_library(tmp_dir)
assert lib.engine
def test_library_add_file():
"""Check Entry.path handling for insert vs lookup"""
with TemporaryDirectory() as tmp_dir:
# create file in tmp_dir
file_path = Path(tmp_dir) / "bar.txt"
file_path.write_text("bar")
lib = Library()
lib.open_library(tmp_dir)
entry = Entry(
path=file_path,
folder=lib.folder,
fields=lib.default_fields,
)
assert not lib.has_path_entry(entry.path)
assert lib.add_entries([entry])
assert lib.has_path_entry(entry.path) is True
def test_create_tag(library, generate_tag):
# tag already exists
assert not library.add_tag(generate_tag("foo"))
# new tag name
tag = library.add_tag(generate_tag("xxx", id=123))
assert tag
assert tag.id == 123
tag_inc = library.add_tag(generate_tag("yyy"))
assert tag_inc.id > 1000
def test_library_search(library, generate_tag, entry_full):
assert library.entries_count == 2
tag = list(entry_full.tags)[0]
query_count, items = library.search_library(
FilterState(
tag=tag.name,
),
)
assert query_count == 1
assert len(items) == 1
entry = items[0]
assert {x.name for x in entry.tags} == {
"foo",
}
assert entry.tag_box_fields
def test_tag_search(library):
tag = library.tags[0]
assert library.search_tags(
FilterState(tag=tag.name.lower()),
)
assert library.search_tags(
FilterState(tag=tag.name.upper()),
)
assert not library.search_tags(
FilterState(tag=tag.name * 2),
)
def test_get_entry(library, entry_min):
assert entry_min.id
cnt, entries = library.search_library(FilterState(id=entry_min.id))
assert len(entries) == cnt == 1
assert entries[0].tags
def test_entries_count(library):
entries = [
Entry(path=Path(f"{x}.txt"), folder=library.folder, fields=[])
for x in range(10)
]
library.add_entries(entries)
matches, page = library.search_library(
FilterState(
page_size=5,
)
)
assert matches == 12
assert len(page) == 5
def test_add_field_to_entry(library):
# Given
entry = Entry(
folder=library.folder,
path=Path("xxx"),
fields=library.default_fields,
)
# meta tags + content tags
assert len(entry.tag_box_fields) == 2
library.add_entries([entry])
# When
library.add_entry_field_type(entry.id, field_id=_FieldID.TAGS)
# Then
entry = [x for x in library.get_entries(with_joins=True) if x.path == entry.path][0]
# meta tags and tags field present
assert len(entry.tag_box_fields) == 3
def test_add_field_tag(library, entry_full, generate_tag):
# Given
tag_name = "xxx"
tag = generate_tag(tag_name)
tag_field = entry_full.tag_box_fields[0]
# When
library.add_field_tag(entry_full, tag, tag_field.type_key)
# Then
_, entries = library.search_library(FilterState(id=entry_full.id))
tag_field = entries[0].tag_box_fields[0]
assert [x.name for x in tag_field.tags if x.name == tag_name]
def test_subtags_add(library, generate_tag):
# Given
tag = library.tags[0]
assert tag.id is not None
subtag = generate_tag("subtag1")
subtag = library.add_tag(subtag)
assert subtag.id is not None
# When
assert library.add_subtag(tag.id, subtag.id)
# Then
assert tag.id is not None
tag = library.get_tag(tag.id)
assert tag.subtag_ids
@pytest.mark.parametrize("is_exclude", [True, False])
def test_search_filter_extensions(library, is_exclude):
# Given
entries = list(library.get_entries())
assert len(entries) == 2, entries
library.set_prefs(LibraryPrefs.IS_EXCLUDE_LIST, is_exclude)
library.set_prefs(LibraryPrefs.EXTENSION_LIST, ["md"])
# When
query_count, items = library.search_library(
FilterState(),
)
# Then
assert query_count == 1
assert len(items) == 1
entry = items[0]
assert (entry.path.suffix == ".txt") == is_exclude
def test_search_library_case_insensitive(library):
# Given
entries = list(library.get_entries(with_joins=True))
assert len(entries) == 2, entries
entry = entries[0]
tag = list(entry.tags)[0]
# When
query_count, items = library.search_library(
FilterState(tag=tag.name.upper()),
)
# Then
assert query_count == 1
assert len(items) == 1
assert items[0].id == entry.id
def test_preferences(library):
for pref in LibraryPrefs:
assert library.prefs(pref) == pref.value
def test_save_windows_path(library, generate_tag):
# pretend we are on windows and create `Path`
entry = Entry(
path=PureWindowsPath("foo\\bar.txt"),
folder=library.folder,
fields=library.default_fields,
)
tag = generate_tag("win_path")
tag_name = tag.name
library.add_entries([entry])
# library.add_tag(tag)
library.add_field_tag(entry, tag, create_field=True)
_, found = library.search_library(FilterState(tag=tag_name))
assert found
# path should be saved in posix format
assert str(found[0].path) == "foo/bar.txt"
def test_remove_entry_field(library, entry_full):
title_field = entry_full.text_fields[0]
library.remove_entry_field(title_field, [entry_full.id])
entry = next(library.get_entries(with_joins=True))
assert not entry.text_fields
def test_remove_field_entry_with_multiple_field(library, entry_full):
# Given
title_field = entry_full.text_fields[0]
# When
# add identical field
assert library.add_entry_field_type(entry_full.id, field_id=title_field.type_key)
# remove entry field
library.remove_entry_field(title_field, [entry_full.id])
# Then one field should remain
entry = next(library.get_entries(with_joins=True))
assert len(entry.text_fields) == 1
def test_update_entry_field(library, entry_full):
title_field = entry_full.text_fields[0]
library.update_entry_field(
entry_full.id,
title_field,
"new value",
)
entry = next(library.get_entries(with_joins=True))
assert entry.text_fields[0].value == "new value"
def test_update_entry_with_multiple_identical_fields(library, entry_full):
# Given
title_field = entry_full.text_fields[0]
# When
# add identical field
library.add_entry_field_type(entry_full.id, field_id=title_field.type_key)
# update one of the fields
library.update_entry_field(
entry_full.id,
title_field,
"new value",
)
# Then only one should be updated
entry = next(library.get_entries(with_joins=True))
assert entry.text_fields[0].value == ""
assert entry.text_fields[1].value == "new value"
def test_mirror_entry_fields(library, entry_full):
target_entry = Entry(
folder=library.folder,
path=Path("xxx"),
fields=[
TextField(
type_key=_FieldID.NOTES.name,
value="notes",
position=0,
)
],
)
entry_id = library.add_entries([target_entry])[0]
_, entries = library.search_library(FilterState(id=entry_id))
new_entry = entries[0]
library.mirror_entry_fields(new_entry, entry_full)
_, entries = library.search_library(FilterState(id=entry_id))
entry = entries[0]
assert len(entry.fields) == 4
assert {x.type_key for x in entry.fields} == {
_FieldID.TITLE.name,
_FieldID.NOTES.name,
_FieldID.TAGS_META.name,
_FieldID.TAGS.name,
}
def test_remove_tag_from_field(library, entry_full):
for field in entry_full.tag_box_fields:
for tag in field.tags:
removed_tag = tag.name
library.remove_tag_from_field(tag, field)
break
entry = next(library.get_entries(with_joins=True))
for field in entry.tag_box_fields:
assert removed_tag not in [tag.name for tag in field.tags]
@pytest.mark.parametrize(
["query_name", "has_result"],
[
("foo", 1), # filename substring
("bar", 1), # filename substring
("one", 0), # path, should not match
],
)
def test_search_file_name(library, query_name, has_result):
res_count, items = library.search_library(
FilterState(name=query_name),
)
assert (
res_count == has_result
), f"mismatch with query: {query_name}, result: {res_count}"
@pytest.mark.parametrize(
["query_name", "has_result"],
[
(1, 1),
("1", 1),
("xxx", 0),
(222, 0),
],
)
def test_search_entry_id(library, query_name, has_result):
res_count, items = library.search_library(
FilterState(id=query_name),
)
assert (
res_count == has_result
), f"mismatch with query: {query_name}, result: {res_count}"
def test_update_field_order(library, entry_full):
# Given
title_field = entry_full.text_fields[0]
# When add two more fields
library.add_entry_field_type(
entry_full.id, field_id=title_field.type_key, value="first"
)
library.add_entry_field_type(
entry_full.id, field_id=title_field.type_key, value="second"
)
# remove the one on first position
assert title_field.position == 0
library.remove_entry_field(title_field, [entry_full.id])
# recalculate the positions
library.update_field_position(
type(title_field),
title_field.type_key,
entry_full.id,
)
# Then
entry = next(library.get_entries(with_joins=True))
assert entry.text_fields[0].position == 0
assert entry.text_fields[0].value == "first"
assert entry.text_fields[1].position == 1
assert entry.text_fields[1].value == "second"