Source code for ahvn.klstore.vdb_store

from __future__ import annotations

__all__ = [
    "VectorKLStore",
]

from typing import Any, Generator, Iterable, List, Optional, Callable, TYPE_CHECKING

if TYPE_CHECKING:
    from llama_index.core.schema import TextNode

from ..adapter.vdb import VdbUKFAdapter
from ..ukf.base import BaseUKF
from ..utils.basic.config_utils import HEAVEN_CM
from ..utils.basic.log_utils import get_logger
from ..utils.basic.misc_utils import unique
from ..utils.basic.progress_utils import Progress
from ..utils.vdb.base import VectorDatabase
from .base import BaseKLStore

logger = get_logger(__name__)


class VectorKLStore(BaseKLStore):
    """\
    Vector database backed KL store using the VDB adapter.

    Minimal implementation that maps vector database records to BaseUKF objects.
    """
    def __init__(self, collection: Optional[str] = None, name: Optional[str] = None,
                 condition: Optional[Callable] = None, *args, **kwargs):
        """\
        Initialize the vector database KL store.

        Args:
            collection: Vector database collection or table name.
            name: Name of the KLStore instance. If None, defaults to the collection
                name or "default".
            condition: Optional upsert/insert condition to apply to the KLStore.
                KLs that do not satisfy the condition will be ignored. If None,
                all KLs are accepted.
            *args: Additional positional arguments for BaseKLStore.
            **kwargs: Additional keyword arguments for adapter or vector database
                configuration.
        """
        super().__init__(name=name or collection, condition=condition, *args, **kwargs)
        provider = kwargs.get("provider") or HEAVEN_CM.get("vdb.default_provider", "lancedb")
        encoder = kwargs.get("encoder")
        embedder = kwargs.get("embedder")
        include = kwargs.get("include")
        exclude = kwargs.get("exclude")
        collection = collection or kwargs.get("collection") or HEAVEN_CM.get(f"vdb.providers.{provider}.collection")
        connection_args = {
            k: v
            for k, v in kwargs.items()
            if k not in {
                "collection",
                "provider",
                "encoder",
                "embedder",
                "include",
                "exclude",
            }
        }
        self.vdb = VectorDatabase(collection=collection or self.name, provider=provider,
                                  encoder=encoder, embedder=embedder, **connection_args)
        adapter_kwargs = {
            "backend": self.vdb.backend,
            "name": self.name,
            "include": include,
            "exclude": exclude,
        }
        self.adapter = VdbUKFAdapter(**adapter_kwargs)
        self._init()
    def _init(self):
        self.vdb.connect()
        # Insert a dummy record to specify the schema
        dummy = BaseUKF(name="__dummy__", type="dummy")
        self.vdb.vdb.add(self._batch_convert([dummy]))
        self.vdb.flush()
        # Remove the dummy node
        self._remove(dummy.id)

    def _has(self, key: int) -> bool:
        ukf_id = self.adapter.parse_id(key)
        entities = self.vdb.vdb.get_nodes(node_ids=[ukf_id])
        if len(entities) > 1:
            raise ValueError(f"Multiple entities found for key {key} (id: {ukf_id})")
        return len(entities) == 1

    def _get(self, key: int, default: Any = ...) -> Optional[BaseUKF]:
        ukf_id = self.adapter.parse_id(key)
        entities = self.vdb.vdb.get_nodes(node_ids=[ukf_id])
        if len(entities) > 1:
            raise ValueError(f"Multiple entities found for key {key} (id: {ukf_id})")
        if len(entities) < 1:
            return default
        return self.adapter.to_ukf(entity=entities[0])

    def _batch_convert(self, kls: Iterable[BaseUKF]) -> List[TextNode]:
        nodes = list()
        keys_embeddings = self.vdb.batch_k_encode_embed(kls)
        for kl, (key, embedding) in zip(kls, keys_embeddings):
            nodes.append(self.adapter.from_ukf(kl=kl, key=key, embedding=embedding))
        return nodes

    def _upsert(self, kl: BaseUKF, **kwargs):
        ukf_id = self.adapter.parse_id(kl.id)
        self.vdb.vdb.delete_nodes([ukf_id])
        self.vdb.vdb.add(self._batch_convert([kl]))

    def _batch_upsert(self, kls: list[BaseUKF], progress: Progress = None, **kwargs):
        kls = unique(kls, key=lambda kl: kl.id)  # Keeping only the first occurrence of each ID in case of duplicates
        if not kls:
            return
        ukf_ids = [self.adapter.parse_id(kl.id) for kl in kls]
        self.vdb.vdb.delete_nodes(ukf_ids)
        self.vdb.vdb.add(self._batch_convert(kls))
        if progress is not None:
            progress.update(len(kls))

    def _batch_insert(self, kls: list[BaseUKF], progress: Progress = None, **kwargs):
        kls = unique(kls, key=lambda kl: kl.id)  # Keeping only the first occurrence of each ID in case of duplicates
        if not kls:
            return
        ukf_ids = [self.adapter.parse_id(kl.id) for kl in kls]
        existing = set(node.node_id for node in self.vdb.vdb.get_nodes(node_ids=ukf_ids))
        delta = [kl for kl, ukf_id in zip(kls, ukf_ids) if ukf_id not in existing]
        if not delta:
            return
        self.vdb.vdb.add(self._batch_convert(delta))
        if progress is not None:
            progress.update(len(delta))

    def _remove(self, key: int, **kwargs) -> bool:
        if key not in self:
            return False
        self.vdb.vdb.delete_nodes([self.adapter.parse_id(key)])
        return True

    def _batch_remove(self, keys: Iterable[int], progress: Progress = None, **kwargs):
        keys = unique(keys)  # Keeping only unique keys
        if not keys:
            return
        ukf_ids = [self.adapter.parse_id(key) for key in keys if key in self]
        if not ukf_ids:
            return
        self.vdb.vdb.delete_nodes(ukf_ids)
        if progress is not None:
            progress.update(len(ukf_ids))

    def __len__(self) -> int:
        return len(self.vdb._get_all_nodes())

    def _itervalues(self) -> Generator[BaseUKF, None, None]:
        for node in self.vdb._get_all_nodes():
            yield self.adapter.to_ukf(entity=node)

    def _clear(self):
        self.vdb.clear()
    def close(self):
        if self.vdb is not None:
            self.vdb.close()
            self.vdb = None
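
A minimal usage sketch follows. It assumes a configured "lancedb" provider and that BaseKLStore exposes a public upsert wrapper and membership test delegating to the private _upsert and _has hooks above; the collection name and BaseUKF fields are illustrative only.

    # Illustrative usage sketch (not part of the module).
    from ahvn.klstore.vdb_store import VectorKLStore
    from ahvn.ukf.base import BaseUKF

    store = VectorKLStore(collection="example_notes", provider="lancedb")  # names are assumptions
    kl = BaseUKF(name="greeting", type="note")                             # fields are assumptions
    store.upsert(kl)       # public wrapper assumed to delegate to _upsert
    assert kl.id in store  # membership test assumed to delegate to _has
    store.close()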