{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "13ZsKhrZfklw"
      },
      "source": [
        "# Improving PostgreSQL Keyword Search to Avoid Empty Results\n",
        "\n",
        "*Notebook by [Mayank Laddha](https://www.linkedin.com/in/mayankladdha31/)*\n",
        "\n",
        "As noted in the Haystack documentation for [PgvectorKeywordRetriever](https://docs.haystack.deepset.ai/docs/pgvectorkeywordretriever), this component, unlike others such as `ElasticsearchBM25Retriever`, doesn’t apply fuzzy search by default. As a result, queries need to be crafted carefully to avoid returning empty results.\n",
        "\n",
        "In this notebook, you’ll extend [PgvectorDocumentStore](https://docs.haystack.deepset.ai/docs/pgvectordocumentstore#/) to make it more forgiving and flexible. You’ll learn how to subclass it to use PostgreSQL’s `websearch_to_tsquery` and how to leverage NLTK to extract keywords and transform user queries.\n",
        "\n",
        "Haystack’s modular design makes it easy to tweak or enhance components when results don’t meet expectations and this notebook will show exactly how to do that."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tHZV5A7Jfkly"
      },
      "source": [
        "## Setting up the Development Environment\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xBtQpazHyr1w"
      },
      "source": [
        "Install required dependencies and set up PostgreSQL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dTEssDnnywZ9",
        "outputId": "4e38dcfd-0d3a-40c0-fd3d-6f77cd238538"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "pip install -q haystack-ai pgvector-haystack psycopg"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ySQFKJsufklz",
        "outputId": "a8be5495-c178-4b9e-ba0a-9ad4129cfa3e"
      },
      "outputs": [],
      "source": [
        "#The output of the installation is not displayed when %%capture is used at the start of the cell\n",
        "%%capture\n",
        "# Install PostgreSQL and pgvector (version-agnostic)\n",
        "!sudo apt-get -y -qq update\n",
        "!sudo apt-get -y -qq install postgresql postgresql-server-dev-all git make gcc\n",
        "\n",
        "# Build and install pgvector manually (works for any PostgreSQL version)\n",
        "!git clone --quiet https://github.com/pgvector/pgvector.git\n",
        "!cd pgvector && make && sudo make install\n",
        "\n",
        "# Start PostgreSQL service\n",
        "!sudo service postgresql start\n",
        "\n",
        "# Set password for default user\n",
        "!sudo -u postgres psql -c \"ALTER USER postgres PASSWORD 'postgres';\"\n",
        "\n",
        "# Create database only if it doesn't exist\n",
        "!sudo -u postgres psql -tc \"SELECT 1 FROM pg_database WHERE datname = 'sampledb'\" | grep -q 1 || \\\n",
        "sudo -u postgres psql -c \"CREATE DATABASE sampledb;\"\n",
        "\n",
        "# Enable pgvector extension in sampledb\n",
        "!sudo -u postgres psql -d sampledb -c \"CREATE EXTENSION IF NOT EXISTS vector;\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3YNQ1_MllrQX"
      },
      "source": [
        "Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database. This is needed for Haystack."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "RAuddSHHl9dg"
      },
      "outputs": [],
      "source": [
        "# set connection\n",
        "import os\n",
        "os.environ[\"PG_CONN_STR\"] = \"postgresql://postgres:postgres@localhost:5432/sampledb\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hgYlL9qXfklz"
      },
      "source": [
        "## Subclassing `PgvectorDocumentStore` to Enable Websearch-Style Queries"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "54XP-n3Rw3-l"
      },
      "source": [
        "Why not `plainto_tsquery`? Why `websearch_to_tsquery`?\n",
        "\n",
        "`plainto_tsquery` transforms the unformatted text querytext to a `tsquery` value. The text is parsed and normalized much as for to_tsvector, then the & (AND) `tsquery` operator is inserted between surviving words. so all your keywords need to be present in the document.\n",
        "\n",
        "`websearch_to_tsquery` creates a `tsquery` value from querytext using an alternative syntax in which simple unformatted text is a valid query. Unlike `plainto_tsquery` and `phraseto_tsquery`, it also recognizes certain operators. Moreover, this function will never raise syntax errors, which makes it possible to use raw user-supplied input for search. The following syntax is supported:\n",
        "\n",
        "**unquoted text**: text not inside quote marks will be converted to terms separated by & operators, as if processed by `plainto_tsquery`.\n",
        "\n",
        "**\"quoted text\"**: text inside quote marks will be converted to terms separated by <-> operators, as if processed by `phraseto_tsquery`.\n",
        "\n",
        "**OR**: the word “or” will be converted to the | operator.\n",
        "\n",
        "**-**: a dash will be converted to the ! operator.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "id": "QrLmhrf6mzCr"
      },
      "outputs": [],
      "source": [
        "from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore\n",
        "from psycopg.sql import SQL, Composed, Identifier, Literal as SQLLiteral\n",
        "from typing import Dict, Any, Optional, Tuple, Union\n",
        "\n",
        "class CustomPgvectorDocumentStore(PgvectorDocumentStore):\n",
        "    def _build_keyword_retrieval_query(\n",
        "        self, query: str, top_k: int, filters: Optional[Dict[str, Any]] = None\n",
        "    ) -> Tuple[Composed, tuple]:\n",
        "\n",
        "        # Replace plainto_tsquery with websearch_to_tsquery\n",
        "        KEYWORD_QUERY_CUSTOM = \"\"\"\n",
        "        SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score\n",
        "        FROM {schema_name}.{table_name}, websearch_to_tsquery({language}, %s) query\n",
        "        WHERE to_tsvector({language}, content) @@ query\n",
        "        \"\"\"\n",
        "        sql_select = SQL(KEYWORD_QUERY_CUSTOM).format(\n",
        "            schema_name=Identifier(self.schema_name),\n",
        "            table_name=Identifier(self.table_name),\n",
        "            language=SQLLiteral(self.language),\n",
        "            query=SQLLiteral(query),\n",
        "        )\n",
        "\n",
        "        where_params = ()\n",
        "        sql_where_clause: Union[Composed, SQL] = SQL(\"\")\n",
        "        if filters:\n",
        "            sql_where_clause, where_params = self._convert_filters_to_where_clause_and_params(\n",
        "                filters=filters, operator=\"AND\"\n",
        "            )\n",
        "\n",
        "        sql_sort = SQL(\" ORDER BY score DESC LIMIT {top_k}\").format(top_k=SQLLiteral(top_k))\n",
        "        sql_query = sql_select + sql_where_clause + sql_sort\n",
        "\n",
        "        return sql_query, where_params\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fcWJu3pMfklz"
      },
      "source": [
        "## Detect Keywords with NLTK"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "osBpnjr6xmIz"
      },
      "source": [
        "Detecting keywords make sure we use only the relevant words. So, even if you decide to use the default implementation with `plainto_tsquery`, which uses AND operator, you stil have better chances of not getting zero results."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ioTu4XlCoy_R"
      },
      "source": [
        "Download required packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gaYeFjEOo1jk",
        "outputId": "1c1cae98-1d9b-4d9b-f36f-2797223088e1"
      },
      "outputs": [],
      "source": [
        "import nltk\n",
        "nltk.download('punkt_tab')\n",
        "nltk.download('averaged_perceptron_tagger_eng')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AMluHSWyo-nu"
      },
      "source": [
        "Simple keyword detector"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "FSpsnpd0oozv"
      },
      "outputs": [],
      "source": [
        "from nltk import word_tokenize, pos_tag\n",
        "def extract_keywords(query: str):\n",
        "    tokens = word_tokenize(query)\n",
        "    nouns = [word for word, pos in pos_tag(tokens) if pos.startswith(\"NN\")]\n",
        "    return nouns[:5]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IvAzR5czp6DX"
      },
      "source": [
        "## Test the Improved Implementation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hMVwCos0qC1Y",
        "outputId": "dc16d650-fb3d-4e55-f9aa-1bd19ac086d1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "transformed query Jean OR Mayank OR Alex OR Paris\n",
            "result {'documents': [Document(id=1, content: 'My name is Jean and I live in Paris.', score: 0.2)]}\n"
          ]
        }
      ],
      "source": [
        "from haystack import Document\n",
        "from haystack_integrations.components.retrievers.pgvector import PgvectorKeywordRetriever\n",
        "import psycopg\n",
        "\n",
        "#use our custom store instead of PgvectorDocumentStore, that's it\n",
        "document_store = CustomPgvectorDocumentStore()\n",
        "\n",
        "#rest of the flow/pipeline will remain the same\n",
        "retriever = PgvectorKeywordRetriever(document_store=document_store,top_k = 1)\n",
        "\n",
        "document_store.write_documents([\n",
        "    Document(id = \"1\" ,content=\"My name is Jean and I live in Paris.\"),\n",
        "    Document(id = \"2\", content=\"My name is Mark and I live in Berlin.\"),\n",
        "    Document(id = \"3\", content=\"My name is Giorgio and I live in Rome.\")\n",
        "])\n",
        "\n",
        "query = \"Do you think Jean, Mayank and Alex live in Paris?\"\n",
        "keywords = extract_keywords(query)\n",
        "transformed_query =  \" OR \".join(keywords)\n",
        "print(\"transformed query\",transformed_query)\n",
        "res = retriever.run(query=transformed_query)\n",
        "print(\"result\",res)\n",
        "\n",
        "#just delete the data\n",
        "schema_name = \"public\"\n",
        "vec_table_name = \"haystack_documents\"\n",
        "vec_full_table_name = f\"{schema_name}.{vec_table_name}\"\n",
        "db_url = os.environ.get(\"PG_CONN_STR\")\n",
        "with psycopg.connect(db_url) as conn:\n",
        "    with conn.cursor() as cur:\n",
        "        cur.execute(f\"TRUNCATE TABLE {vec_full_table_name} RESTART IDENTITY CASCADE;\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nk10BgqB5Gim"
      },
      "source": [
        "## Compare with the Default Keyword Search"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "95M4KAjQ5Fqn",
        "outputId": "030ef81c-9a8a-4e23-d958-7bf367f66f14"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "result {'documents': []}\n"
          ]
        }
      ],
      "source": [
        "#use default store\n",
        "document_store = PgvectorDocumentStore()\n",
        "\n",
        "retriever = PgvectorKeywordRetriever(document_store=document_store,top_k = 1)\n",
        "\n",
        "document_store.write_documents([\n",
        "    Document(id = \"1\" ,content=\"My name is Jean and I live in Paris.\"),\n",
        "    Document(id = \"2\", content=\"My name is Mark and I live in Berlin.\"),\n",
        "    Document(id = \"3\", content=\"My name is Giorgio and I live in Rome.\")\n",
        "])\n",
        "\n",
        "query = \"Do you think Jean, Mayank and Alex live in Paris?\"\n",
        "res = retriever.run(query=transformed_query)\n",
        "print(\"result\",res)\n",
        "\n",
        "#just delete the data\n",
        "schema_name = \"public\"\n",
        "vec_table_name = \"haystack_documents\"\n",
        "vec_full_table_name = f\"{schema_name}.{vec_table_name}\"\n",
        "db_url = os.environ.get(\"PG_CONN_STR\")\n",
        "with psycopg.connect(db_url) as conn:\n",
        "    with conn.cursor() as cur:\n",
        "        cur.execute(f\"TRUNCATE TABLE {vec_full_table_name} RESTART IDENTITY CASCADE;\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BPt4uPAM5ZRE"
      },
      "source": [
        "As you can see the results are empty."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.9.6 64-bit",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.9.6"
    },
    "orig_nbformat": 4,
    "vscode": {
      "interpreter": {
        "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
