From 033bff86570f65e54aee9c5c01f20a20bd79f99f Mon Sep 17 00:00:00 2001
From: Enting Chen
Date: Tue, 15 Jul 2025 16:01:48 +0100
Subject: [PATCH 1/2] index

---
 src/retrieval/graph_retriever/README.md       |  42 ++++++
 .../grag/embed_models/__init__.py             |   2 +
 .../graph_retriever/grag/embed_models/base.py |  31 ++++
 .../grag/embed_models/sbert.py                |  54 +++++++
 .../graph_retriever/grag/index/__init__.py    |   2 +
 .../graph_retriever/grag/index/es.py          | 136 ++++++++++++++++++
 .../graph_retriever/requirements.txt          |  20 +++
 7 files changed, 287 insertions(+)
 create mode 100644 src/retrieval/graph_retriever/README.md
 create mode 100644 src/retrieval/graph_retriever/grag/embed_models/__init__.py
 create mode 100644 src/retrieval/graph_retriever/grag/embed_models/base.py
 create mode 100644 src/retrieval/graph_retriever/grag/embed_models/sbert.py
 create mode 100644 src/retrieval/graph_retriever/grag/index/__init__.py
 create mode 100644 src/retrieval/graph_retriever/grag/index/es.py
 create mode 100644 src/retrieval/graph_retriever/requirements.txt

diff --git a/src/retrieval/graph_retriever/README.md b/src/retrieval/graph_retriever/README.md
new file mode 100644
index 0000000..1147a67
--- /dev/null
+++ b/src/retrieval/graph_retriever/README.md
@@ -0,0 +1,42 @@
+# Graph-Based Retrieval
+
+## 1. Environment Setup
+
+Python environment:
+```
+conda create -n grag python=3.12.11
+conda activate grag
+cd src/retrieval/graph_retriever
+pip install -r requirements.txt
+```
+
+## 2. Run Triple Extraction
+
+Triple extraction is driven by `grag/pipeline/extract_triples.py` (added in the second patch of this series): it downloads text chunks from Elasticsearch and prompts an LLM to extract OpenIE triples.
+
+## 3. Indexing
+
+#### 3.1 Text Indexer
+The TextIndexer class builds a text-based index: it splits documents into text chunks (`TextNode`), generates an embedding for each chunk with an SBERT model, and stores the results in Elasticsearch. A usage sketch follows section 3.2 below.
+
+Configuration
+- embed_model: model used to create an embedding for each chunk
+- batch_size: number of documents processed in each batch
+- es_url, es_index: Elasticsearch index in which to store text chunks
+- data_dir: directory of the JSONL file of documents
+
+#### 3.2 Triple Indexer
+The TripleIndexer class builds an index over extracted triples, storing them in Elasticsearch for easy querying.
+
+Configuration
+- batch_size: number of triples processed in each batch
+- es_url, es_index: Elasticsearch index in which to store triples
+- text_es_url, text_es_index: Elasticsearch index of the text chunks
+- data_dir: path to the directory containing triple data
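+
+A minimal usage sketch (the constructor arguments, index name, and model
+checkpoint below are illustrative assumptions, not this repo's actual
+configuration):
+
+```python
+from src.retrieval.graph_retriever.grag.embed_models import SBERT
+
+# TextIndexer is a concrete subclass of grag.index.es.BaseIndexer.
+indexer = TextIndexer(
+    es_index="chunks",                       # es_index
+    es_url="http://localhost:9200",          # es_url (assumed local instance)
+    embed_model=SBERT("all-MiniLM-L6-v2"),   # embed_model (assumed checkpoint)
+)
+indexer.build_index(docs, batch_size=128)    # docs: an iterable of document dicts
+```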
+
+## 4. Run Retrieval Experiments
+
+For example:
+
+```sh
+python -m src.retrieval.local_search
+```
\ No newline at end of file
diff --git a/src/retrieval/graph_retriever/grag/embed_models/__init__.py b/src/retrieval/graph_retriever/grag/embed_models/__init__.py
new file mode 100644
index 0000000..0d2d711
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/embed_models/__init__.py
@@ -0,0 +1,2 @@
+from src.retrieval.graph_retriever.grag.embed_models.base import EmbedModel
+from src.retrieval.graph_retriever.grag.embed_models.sbert import SBERT
\ No newline at end of file
diff --git a/src/retrieval/graph_retriever/grag/embed_models/base.py b/src/retrieval/graph_retriever/grag/embed_models/base.py
new file mode 100644
index 0000000..b6f070f
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/embed_models/base.py
@@ -0,0 +1,31 @@
+from abc import ABCMeta, abstractmethod
+from typing import Any
+
+import torch
+
+
+class EmbedModel(metaclass=ABCMeta):
+    @abstractmethod
+    def embed_docs(
+        self,
+        texts: list[str],
+        batch_size: int | None = None,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        """Embed documents."""
+        pass
+
+    @abstractmethod
+    def embed_query(self, text: str, **kwargs: Any) -> torch.Tensor:
+        """Embed a single query."""
+        pass
+
+    def embed_queries(self, texts: list[str], **kwargs: Any) -> torch.Tensor:
+        """Embed queries.
+
+        Note:
+            Override this method if batched computation should be supported.
+        """
+        return torch.stack([self.embed_query(x, **kwargs) for x in texts])
+
+    def get_embedding_dimension(self) -> int:
+        return self.embed_query("X").shape[-1]
diff --git a/src/retrieval/graph_retriever/grag/embed_models/sbert.py b/src/retrieval/graph_retriever/grag/embed_models/sbert.py
new file mode 100644
index 0000000..fff6152
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/embed_models/sbert.py
@@ -0,0 +1,54 @@
+from typing import Any
+
+import torch
+from sentence_transformers import SentenceTransformer
+
+from src.retrieval.graph_retriever.grag.embed_models.base import EmbedModel
+from src.retrieval.graph_retriever.grag.embed_models.utils import load_sentence_transformer
+
+
+class SBERT(EmbedModel):
+    def __init__(
+        self,
+        model: str | SentenceTransformer,
+        device: str | None = None,
+        **model_args: Any,
+    ) -> None:
+        self._model = (
+            model
+            if isinstance(model, SentenceTransformer)
+            else load_sentence_transformer(model, device=device, **model_args)
+        )
+
+    @property
+    def model(self) -> SentenceTransformer:
+        return self._model
+
+    def embed_docs(
+        self,
+        texts: list[str],
+        batch_size: int = 32,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        return self._model.encode(
+            texts,
+            batch_size=batch_size,
+            convert_to_tensor=True,
+            **kwargs,
+        )
+
+    def embed_query(self, text: str, **kwargs: Any) -> torch.Tensor:
+        # Embed the query as a single-item batch and drop the batch dimension,
+        # matching the EmbedModel.embed_query contract.
+        return self.embed_docs([text], **kwargs)[0]
+
+    def get_embedding_dimension(self) -> int:
+        dim = self.model.get_sentence_embedding_dimension()
+        if not isinstance(dim, int):
+            raise RuntimeError(f"{dim=}; expected int")
+
+        return dim
diff --git a/src/retrieval/graph_retriever/grag/index/__init__.py b/src/retrieval/graph_retriever/grag/index/__init__.py
new file mode 100644
index 0000000..b12a5bc
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/index/__init__.py
@@ -0,0 +1,2 @@
+from src.retrieval.graph_retriever.grag.index.es import BaseESWrapper, BaseIndexer
+from src.retrieval.graph_retriever.grag.index.chunk import TextSplitter, LlamaindexSplitter
\ No newline at end of file
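For reference, the embedding API above can be exercised as follows (a sketch: the checkpoint name is an assumption, and `load_sentence_transformer` must exist in `grag/embed_models/utils.py`, which is not part of this patch):

```python
from src.retrieval.graph_retriever.grag.embed_models import SBERT

model = SBERT("sentence-transformers/all-MiniLM-L6-v2")       # assumed checkpoint
docs = model.embed_docs(["first passage", "second passage"])  # shape: (2, dim)
query = model.embed_query("a question")                       # shape: (dim,)
print(model.get_embedding_dimension())                        # e.g. 384 for MiniLM
```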
diff --git a/src/retrieval/graph_retriever/grag/index/es.py b/src/retrieval/graph_retriever/grag/index/es.py
new file mode 100644
index 0000000..e410f2b
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/index/es.py
@@ -0,0 +1,136 @@
+import asyncio
+import itertools
+from typing import Any, Literal
+from collections.abc import Iterable
+from abc import ABCMeta, abstractmethod
+
+from tqdm import tqdm
+from llama_index.core.schema import TextNode
+from llama_index.vector_stores.elasticsearch import ElasticsearchStore
+from elasticsearch import AsyncElasticsearch
+
+from src.retrieval.graph_retriever.grag.embed_models import EmbedModel
+from src.retrieval.graph_retriever.grag.index.chunk import TextSplitter
+
+
+class BaseESWrapper:
+    """Base class that wraps Elasticsearch and LlamaIndex."""
+
+    def __init__(
+        self,
+        es_index: str,
+        es_url: str,
+        es_client: AsyncElasticsearch | None = None,
+    ) -> None:
+        self.es_index = es_index
+        self.es_url = es_url
+        self.es_client = es_client or AsyncElasticsearch(self.es_url, timeout=600)
+        self._es = ElasticsearchStore(index_name=self.es_index, es_client=self.es_client)
+
+    def __del__(self) -> None:
+        # to suppress warning: "Unclosed client session"
+        asyncio.get_event_loop().run_until_complete(self.es_client.close())
+
+    @property
+    def es(self) -> ElasticsearchStore:
+        return self._es
+
+
+class BaseIndexer(BaseESWrapper, metaclass=ABCMeta):
+    """Abstract base class for indexing.
+
+    Notes:
+        Subclasses must implement data-specific preprocessing and define
+        mappings for their metadata.
+
+    """
+
+    def __init__(
+        self,
+        es_index: str,
+        es_url: str,
+        embed_model: EmbedModel | None = None,
+        splitter: TextSplitter | None = None,
+        es_client: AsyncElasticsearch | None = None,
+    ) -> None:
+        super().__init__(
+            es_index=es_index,
+            es_url=es_url,
+            es_client=es_client,
+        )
+
+        if embed_model and not isinstance(embed_model, EmbedModel):
+            raise TypeError(f"{type(embed_model)=}")
+
+        self.embed_model = embed_model
+        self.splitter = splitter
+
+    @abstractmethod
+    def preprocess(self, doc: dict, splitter: TextSplitter) -> list[TextNode]:
+        """Preprocess a document and return a list of chunks."""
+        pass
+
+    @abstractmethod
+    def get_metadata_mappings(self, **kwargs: Any) -> dict:
+        """Return mappings for metadata.
+
+        Examples:
+            {"properties": {"title": {"type": "text"}}}
+
+        """
+        pass
+
+    async def create_es_index(self, distance_strategy: str = "cosine", analyzer: str | None = None) -> None:
+        """Create the Elasticsearch index.
+
+        Override this method if needed.
+
+        """
+        client: AsyncElasticsearch = self.es.client
+
+        metadata_mappings = self.get_metadata_mappings(analyzer=analyzer)["properties"]
+        # See `llama_index.vector_stores.elasticsearch.ElasticsearchStore`
+        # See also `llama_index.core.vector_stores.utils.node_to_metadata_dict`
+        if "doc_id" in metadata_mappings or "ref_doc_id" in metadata_mappings or "document_id" in metadata_mappings:
+            raise ValueError(
+                "`doc_id`, `ref_doc_id`, and `document_id` are reserved by LlamaIndex. "
+                "Use other field names to avoid potential conflicts and/or unexpected behaviour."
+            )
+
+        await client.indices.create(
+            index=self.es.index_name,
+            mappings={
+                "properties": {
+                    self.es.vector_field: {
+                        "type": "dense_vector",
+                        "dims": self.embed_model.get_embedding_dimension(),
+                        "index": True,
+                        "similarity": distance_strategy,
+                    },
+                    self.es.text_field: ({"type": "text", "analyzer": analyzer} if analyzer else {"type": "text"}),
+                    "metadata": {
+                        "properties": {
+                            # Fields reserved by llama_index; these fields will be overwritten.
+                            # See `llama_index.vector_stores.elasticsearch.ElasticsearchStore`
+                            # See also `llama_index.core.vector_stores.utils.node_to_metadata_dict`
+                            "document_id": {"type": "keyword"},
+                            "doc_id": {"type": "keyword"},
+                            "ref_doc_id": {"type": "keyword"},
+                            **metadata_mappings,
+                        }
+                    },
+                },
+            },
+        )
+
+    def embed_nodes(self, nodes: list[TextNode], batch_size: int = 32) -> list[TextNode]:
+        if self.embed_model is None:
+            return nodes
+
+        texts = [node.text for node in nodes]
+        embeddings = self.embed_model.embed_docs(texts, batch_size=batch_size).tolist()
+        for node, embedding in zip(nodes, embeddings):
+            node.embedding = embedding
+
+        return nodes
diff --git a/src/retrieval/graph_retriever/requirements.txt b/src/retrieval/graph_retriever/requirements.txt
new file mode 100644
index 0000000..1efd71e
--- /dev/null
+++ b/src/retrieval/graph_retriever/requirements.txt
@@ -0,0 +1,20 @@
+elasticsearch==8.17.1
+sentence-transformers==3.4.1
+torch==2.7.0
+llama-index==0.12.36
+llama-index-vector-stores-elasticsearch==0.4.3
+tqdm
+pytest
+loguru
+rapidfuzz
+diskcache
+jsonnet
+more-itertools
+pydantic
+gunicorn
+requests
+flask
+flask[async]
+flask-cors
+huggingface_hub==0.25.2
+ftfy
\ No newline at end of file
--
Gitee

From 5ba19abd52a77e90ea2f9bc0cfe88405198c650a Mon Sep 17 00:00:00 2001
From: Enting Chen
Date: Tue, 15 Jul 2025 17:34:20 +0100
Subject: [PATCH 2/2] index

---
 .../grag/index/chunk/__init__.py             |  2 +
 .../graph_retriever/grag/index/chunk/base.py | 10 +++
 .../grag/index/chunk/llamaindex.py           | 57 ++++++++++++
 .../graph_retriever/grag/index/es.py         | 77 +++++++++++++++++++
 .../grag/pipeline/extract_triples.py         | 59 ++++++++++++
 .../graph_retriever/grag/pipeline/index.py   |  0
 .../grag/pipeline/index_triples.py           | 54 +++++++++++
 .../graph_retriever/grag/pipeline/utils.py   |  0
 8 files changed, 259 insertions(+)
 create mode 100644 src/retrieval/graph_retriever/grag/index/chunk/__init__.py
 create mode 100644 src/retrieval/graph_retriever/grag/index/chunk/base.py
 create mode 100644 src/retrieval/graph_retriever/grag/index/chunk/llamaindex.py
 create mode 100644 src/retrieval/graph_retriever/grag/pipeline/extract_triples.py
 create mode 100644 src/retrieval/graph_retriever/grag/pipeline/index.py
 create mode 100644 src/retrieval/graph_retriever/grag/pipeline/index_triples.py
 create mode 100644 src/retrieval/graph_retriever/grag/pipeline/utils.py

diff --git a/src/retrieval/graph_retriever/grag/index/chunk/__init__.py b/src/retrieval/graph_retriever/grag/index/chunk/__init__.py
new file mode 100644
index 0000000..1a268c2
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/index/chunk/__init__.py
@@ -0,0 +1,2 @@
+from src.retrieval.graph_retriever.grag.index.chunk.base import TextSplitter
+from src.retrieval.graph_retriever.grag.index.chunk.llamaindex import LlamaindexSplitter
\ No newline at end of file
diff --git a/src/retrieval/graph_retriever/grag/index/chunk/base.py b/src/retrieval/graph_retriever/grag/index/chunk/base.py
new file mode 100644
index 0000000..7ad3ba2
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/index/chunk/base.py
@@ -0,0 +1,10 @@
+from abc import ABCMeta, abstractmethod
+
+from llama_index.core.schema import TextNode
+
+
+class TextSplitter(metaclass=ABCMeta):
+
+    @abstractmethod
+    def split(self, text: TextNode) -> list[TextNode]:
+        pass
\ No newline at end of file
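To illustrate the `TextSplitter` interface just added, a trivial implementation might look like this (a sketch only; the repository's real splitter is the `LlamaindexSplitter` below, and the paragraph separator here is an assumption):

```python
from llama_index.core.schema import TextNode

from src.retrieval.graph_retriever.grag.index.chunk.base import TextSplitter


class ParagraphSplitter(TextSplitter):
    """Hypothetical splitter: one chunk per non-empty paragraph."""

    def split(self, text: TextNode) -> list[TextNode]:
        return [
            TextNode(text=part.strip(), metadata=dict(text.metadata))
            for part in text.text.split("\n\n")
            if part.strip()
        ]
```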
diff --git a/src/retrieval/graph_retriever/grag/index/chunk/llamaindex.py b/src/retrieval/graph_retriever/grag/index/chunk/llamaindex.py
new file mode 100644
index 0000000..1337b13
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/index/chunk/llamaindex.py
@@ -0,0 +1,57 @@
+from llama_index.core.schema import TextNode
+from llama_index.core.node_parser import SentenceSplitter
+from transformers import PreTrainedTokenizerBase
+
+from src.retrieval.graph_retriever.grag.index.chunk.base import TextSplitter
+
+
+class LlamaindexSplitter(TextSplitter):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        chunk_size: int | None = None,
+        chunk_overlap: int | None = None,
+        splitter_config: dict | None = None,
+    ) -> None:
+        """Wrapper around LlamaIndex's SentenceSplitter.
+
+        Args:
+            tokenizer (PreTrainedTokenizerBase): Tokenizer.
+            chunk_size (int | None, optional): Chunk size used to split documents into passages.
+                Defaults to None. Note: this is measured in tokens produced by the embedding
+                model's tokenizer. If None, set to the maximum sequence length of the
+                embedding model.
+            chunk_overlap (int | None, optional): Window size for passage overlap. Defaults to None.
+                If None, set to `chunk_size // 5`.
+            splitter_config (dict, optional): Other arguments passed to SentenceSplitter. Defaults to None.
+
+        """
+        super().__init__()
+        if not isinstance(tokenizer, PreTrainedTokenizerBase):
+            raise TypeError(f"{type(tokenizer)=}")
+
+        self._tokenizer = tokenizer
+
+        if not isinstance(splitter_config, dict):
+            splitter_config = {
+                "paragraph_separator": "\n",
+            }
+
+        chunk_size = chunk_size or tokenizer.max_len_single_sentence
+        chunk_size = min(chunk_size, tokenizer.max_len_single_sentence)
+
+        self._splitter = SentenceSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap or chunk_size // 5,
+            tokenizer=self._tokenizer.tokenize,
+            **splitter_config,
+        )
+
+    def split(self, doc: TextNode) -> list[TextNode]:
+        # Note: we don't want the length of metadata to count towards chunking
+        if not doc.excluded_embed_metadata_keys:
+            doc.excluded_embed_metadata_keys = list(doc.metadata.keys())
+
+        if not doc.excluded_llm_metadata_keys:
+            doc.excluded_llm_metadata_keys = list(doc.metadata.keys())
+
+        return self._splitter.get_nodes_from_documents([doc])
\ No newline at end of file
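A usage sketch for the splitter (the tokenizer checkpoint is an assumption; any `PreTrainedTokenizerBase` paired with the embedding model should work):

```python
from llama_index.core.schema import TextNode
from transformers import AutoTokenizer

from src.retrieval.graph_retriever.grag.index.chunk import LlamaindexSplitter

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
splitter = LlamaindexSplitter(tokenizer)  # chunk_size defaults to the tokenizer's max length
nodes = splitter.split(TextNode(text="A long document...", metadata={"title": "Example"}))
```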
diff --git a/src/retrieval/graph_retriever/grag/index/es.py b/src/retrieval/graph_retriever/grag/index/es.py
index e410f2b..f32e6b5 100644
--- a/src/retrieval/graph_retriever/grag/index/es.py
+++ b/src/retrieval/graph_retriever/grag/index/es.py
@@ -134,3 +134,80 @@ class BaseIndexer(BaseESWrapper, metaclass=ABCMeta):
 
         return nodes
 
+    def build_index(
+        self,
+        dataset: Iterable[dict],
+        batch_size: int = 128,
+        distance_strategy: Literal["cosine", "dot_product", "l2_norm"] = "cosine",
+        es_analyzer: str | None = None,
+        *,
+        debug: bool = False,
+    ) -> None:
+        """Build an Elasticsearch index for the input `dataset`.
+
+        Note:
+            1. Adding data to an existing index is not allowed.
+            2. Manually delete an existing index if needed.
+
+        Args:
+            dataset (Iterable[dict]): Dataset of documents.
+            batch_size (int, optional): Batch size for embedding passages. Defaults to 128.
+            distance_strategy (str): Similarity metric supported by Elasticsearch. Defaults to "cosine".
+            es_analyzer (str, optional): Elasticsearch analyzer for the text field. Defaults to None.
+                E.g., use "smartcn" for Chinese text.
+                See: https://www.elastic.co/guide/en/elasticsearch/reference/current/specify-analyzer.html
+            debug (bool, optional): Debug mode. Defaults to False.
+                If True, index the first 100 documents only.
+
+        Raises:
+            RuntimeError: If the index already exists.
+        """
+        if self.embed_model is None:
+            raise NotImplementedError("an embed_model is required: both full-text and vector indices are built by default")
+
+        asyncio.run(
+            self._build_index(
+                dataset,
+                batch_size=batch_size,
+                distance_strategy=distance_strategy,
+                es_analyzer=es_analyzer,
+                debug=debug,
+            )
+        )
+
+    async def _build_index(
+        self,
+        dataset: Iterable[dict],
+        batch_size: int = 128,
+        distance_strategy: str = "cosine",
+        es_analyzer: str | None = None,
+        *,
+        debug: bool = False,
+    ) -> None:
+        client: AsyncElasticsearch = self.es.client
+        if await client.indices.exists(index=self.es.index_name):
+            raise RuntimeError(f"index {self.es.index_name} exists")
+
+        await self.create_es_index(distance_strategy=distance_strategy, analyzer=es_analyzer)
+
+        total = None
+        datastream = dataset
+        if debug:
+            total = 100
+            datastream = itertools.islice(dataset, total)
+
+        cache = []
+        for doc in tqdm(
+            datastream,
+            desc="indexing documents",
+            total=total,
+        ):
+            cache.extend(self.preprocess(doc, self.splitter))
+
+            if len(cache) > batch_size:
+                nodes = self.embed_nodes(cache[:batch_size], batch_size)
+                cache = cache[batch_size:]
+                await self.es.async_add(nodes=nodes, create_index_if_not_exists=False)
+
+        if cache:
+            await self.es.async_add(nodes=self.embed_nodes(cache, batch_size), create_index_if_not_exists=False)
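As a sketch, a concrete indexer (such as the `TripleIndexer` defined later in this patch) can then be driven like this; the URL, index name, checkpoint, and `read_jsonl` helper are all illustrative assumptions:

```python
import json

from src.retrieval.graph_retriever.grag.embed_models import SBERT
from src.retrieval.graph_retriever.grag.pipeline.index_triples import TripleIndexer


def read_jsonl(path):
    # Hypothetical helper: stream documents from a JSONL file.
    with open(path, encoding="utf-8") as f:
        for line in f:
            yield json.loads(line)


indexer = TripleIndexer(
    es_index="triples",
    es_url="http://localhost:9200",
    embed_model=SBERT("sentence-transformers/all-MiniLM-L6-v2"),
)
indexer.build_index(read_jsonl("triples.jsonl"), batch_size=128, debug=True)  # debug: first 100 docs
```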
diff --git a/src/retrieval/graph_retriever/grag/pipeline/extract_triples.py b/src/retrieval/graph_retriever/grag/pipeline/extract_triples.py
new file mode 100644
index 0000000..8c0458c
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/pipeline/extract_triples.py
@@ -0,0 +1,59 @@
+import asyncio
+import json
+import os
+
+from tqdm import tqdm
+from elasticsearch import Elasticsearch
+
+from src.llm.llm_wrapper import LLMWrapper
+from src.retrieval.graph_retriever.grag.utils import load_jsonl, DATA_DIR
+from src.retrieval.graph_retriever.grag.utils.es import iter_index
+from src.retrieval.graph_retriever.grag.reranker.llm_openie import PROMPT as PROMPT_TEMPLATE, LLMOpenIE
+
+
+ES_HOST = os.getenv("CHUNK_ES_URL")
+CHUNK_FILE_PATH = DATA_DIR / "triple_extraction" / "example_chunks.jsonl"
+
+
+async def process_chunk(chunk, save_path):
+    prompt = PROMPT_TEMPLATE.format(passage=chunk["content"], wiki_title=chunk["title"])
+    completion = await LLMWrapper("basic").ainvoke(prompt)
+    _, triples_list = LLMOpenIE.match_entities_triples(completion.content)
+    buffer = {chunk["content"]: triples_list}
+
+    with open(save_path, "a") as f:
+        f.write(json.dumps(buffer, ensure_ascii=False) + "\n")
+
+
+async def process_data(data, save_path, start_idx=0):
+    # Note: all chunks are submitted at once and processed concurrently.
+    tasks = []
+    for chunk in tqdm(data[start_idx:], desc="Processing chunks"):
+        task = asyncio.create_task(process_chunk(chunk, save_path))
+        tasks.append(task)
+
+    await asyncio.gather(*tasks)
+
+
+def load_index() -> None:
+    """Download all chunks from Elasticsearch into a local JSONL file."""
+    es = Elasticsearch(ES_HOST)
+    with open(CHUNK_FILE_PATH, "w+") as f:
+        for batch in tqdm(iter_index(es, os.getenv("CHUNK_ES_INDEX")), desc="downloading chunks..."):
+            for item in batch:
+                content = item["_source"]["content"]
+                title = item["_source"]["metadata"]["title"]
+                f.write(json.dumps({"title": title, "content": content}, ensure_ascii=False))
+                f.write("\n")
+
+
+def main():
+    load_index()
+    asyncio.run(
+        process_data(
+            load_jsonl(CHUNK_FILE_PATH),
+            DATA_DIR / "triple_extraction" / "chunk2triple_completions.jsonl",
+            start_idx=0,
+        )
+    )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/retrieval/graph_retriever/grag/pipeline/index.py b/src/retrieval/graph_retriever/grag/pipeline/index.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/retrieval/graph_retriever/grag/pipeline/index_triples.py b/src/retrieval/graph_retriever/grag/pipeline/index_triples.py
new file mode 100644
index 0000000..d5eb6bc
--- /dev/null
+++ b/src/retrieval/graph_retriever/grag/pipeline/index_triples.py
@@ -0,0 +1,54 @@
+import json
+import os
+from typing import Any
+
+from llama_index.core.schema import TextNode
+from elasticsearch import Elasticsearch
+
+from src.retrieval.graph_retriever.grag.index import BaseIndexer
+from src.retrieval.graph_retriever.grag.embed_models import SBERT
+from src.retrieval.graph_retriever.grag.index.chunk import TextSplitter
+from src.retrieval.graph_retriever.grag.utils import DATA_DIR
+from src.retrieval.graph_retriever.grag.pipeline.utils import prepare_triples
+
+
+class TripleIndexer(BaseIndexer):
+    def preprocess(self, doc: dict, splitter: TextSplitter) -> list[TextNode]:
+        return [TextNode(text=doc["text"], metadata=doc["metadata"])]
+
+    def get_metadata_mappings(self, **kwargs: Any) -> dict:
+        return {
+            "properties": {
+                "chunk_id": {"type": "keyword"},
+                "triple": {"type": "text", "index": False},
+            }
+        }
+
+
+def main():
+    embed_model = SBERT(os.getenv("EMBED_MODEL"))
+    es_url = os.getenv("TRIPLES_ES_URL")
+    es_index = os.getenv("TRIPLES_ES_INDEX")
+    text_es_url = os.getenv("CHUNK_ES_URL")
+    text_es_index = os.getenv("CHUNK_ES_INDEX")
+    data_dir = DATA_DIR / "triple_extraction" / "chunk2triple_completions.jsonl"
+    batch_size = 1024
+
+    indexer = TripleIndexer(
+        es_index=es_index,
+        es_url=es_url,
+        embed_model=embed_model,
+    )
+
+    chunk2triples = {}
+    with open(data_dir, "r", encoding="utf-8") as f:
+        for line in f:
+            chunk2triples.update(json.loads(line))
+
+    datastream = prepare_triples(Elasticsearch(text_es_url), chunk2triples, text_es_index)
+
+    indexer.build_index(datastream, batch_size=batch_size, debug=False)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/retrieval/graph_retriever/grag/pipeline/utils.py b/src/retrieval/graph_retriever/grag/pipeline/utils.py
new file mode 100644
index 0000000..e69de29
--
Gitee