{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cb7RLf9gpEoN"
      },
      "source": [
        "# RAG pipeline using FastEmbed for embeddings generation\n",
        "\n",
        "[FastEmbed](https://qdrant.github.io/fastembed/) is a lightweight, fast, Python library built for embedding generation, maintained by Qdrant. \n",
        "It is suitable for generating embeddings efficiently and fast on CPU-only machines.\n",
        "\n",
        "In this notebook, we will use FastEmbed-Haystack integration to generate embeddings for indexing and RAG.\n",
        "\n",
        "**Haystack Useful Sources**\n",
        "\n",
        "* [Docs](https://docs.haystack.deepset.ai/docs/intro)\n",
        "* [Tutorials](https://haystack.deepset.ai/tutorials)\n",
        "* [Other Cookbooks](https://github.com/deepset-ai/haystack-cookbook)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tnSq1XK_ovZV"
      },
      "outputs": [],
      "source": [
        "!pip install fastembed-haystack qdrant-haystack wikipedia transformers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x8Bpy1ri_Ipx"
      },
      "source": [
        "## Download contents and create docs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "2NIxMNQTzcfV"
      },
      "outputs": [],
      "source": [
        "favourite_bands=\"\"\"Audioslave\n",
        "Green Day\n",
        "Muse (band)\n",
        "Foo Fighters (band)\n",
        "Nirvana (band)\"\"\".split(\"\\n\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "7FpCSSnUzuuP"
      },
      "outputs": [],
      "source": [
        "import wikipedia\n",
        "from haystack.dataclasses import Document\n",
        "\n",
        "raw_docs=[]\n",
        "for title in favourite_bands:\n",
        "    page = wikipedia.page(title=title, auto_suggest=False)\n",
        "    doc = Document(content=page.content, meta={\"title\": page.title, \"url\":page.url})\n",
        "    raw_docs.append(doc)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DLiNhYKV_g8u"
      },
      "source": [
        "## Clean, split and index documents on Qdrant"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "a1taDmfx1HCM"
      },
      "outputs": [],
      "source": [
        "from haystack_integrations.document_stores.qdrant import QdrantDocumentStore\n",
        "from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter\n",
        "from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder\n",
        "from haystack.document_stores.types import DuplicatePolicy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "eAKP9icf1Inj"
      },
      "outputs": [],
      "source": [
        "document_store = QdrantDocumentStore(\n",
        "    \":memory:\",\n",
        "    embedding_dim =384,\n",
        "    recreate_index=True,\n",
        "    return_embedding=True,\n",
        "    wait_result_from_api=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6rbZVTxR907n"
      },
      "outputs": [],
      "source": [
        "cleaner = DocumentCleaner()\n",
        "splitter = DocumentSplitter(split_by='period', split_length=3)\n",
        "splitted_docs = splitter.run(cleaner.run(raw_docs)[\"documents\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "493"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "len(splitted_docs[\"documents\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### FastEmbed Document Embedder\n",
        "\n",
        "Here we are initializing the FastEmbed Document Embedder and using it to generate embeddings for the documents.\n",
        "We are using a small and good model, `BAAI/bge-small-en-v1.5` and specifying the `parallel` parameter to 0 to use all available CPU cores for embedding generation.\n",
        "\n",
        "⚠️ If you are running this notebook on Google Colab, please note that Google Colab only provides 2 CPU cores, so the embedding generation could be not as fast as it can be on a standard machine.\n",
        "\n",
        "For more information on FastEmbed-Haystack integration, please refer to the [documentation](https://docs.haystack.deepset.ai/docs/fastembeddocumentembedder) and [API reference](https://docs.haystack.deepset.ai/reference/fastembed-embedders)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U3s_68uw9FSW"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a0d7dbcf196047dfa767e6e6b78374b2",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 148034.26it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 32458.07it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 223365.30it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 55758.84it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 81884.46it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 140853.49it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 105443.40it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 112014.05it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 76260.07it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 123766.35it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 63443.25it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 55431.33it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 82782.32it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 57368.90it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 9792.15it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 8983.52it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 10585.74it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 59634.65it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 46260.71it/s]\n",
            "Fetching 9 files: 100%|██████████| 9/9 [00:00<00:00, 36900.04it/s]\n",
            "Calculating embeddings: 100%|██████████| 493/493 [00:35<00:00, 13.73it/s]\n"
          ]
        }
      ],
      "source": [
        "document_embedder = FastembedDocumentEmbedder(model=\"BAAI/bge-small-en-v1.5\", parallel = 0, meta_fields_to_embed=[\"title\"])\n",
        "documents_with_embeddings = document_embedder.run(splitted_docs[\"documents\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "dyBwOzM-9Tqm"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "500it [00:00, 4262.26it/s]             \n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "493"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "document_store.write_documents(documents_with_embeddings.get(\"documents\"), policy=DuplicatePolicy.OVERWRITE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IdsqgLzE11_P"
      },
      "source": [
        "## RAG Pipeline using Qwen 2.5 7B"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HTlEGit3_XFk"
      },
      "outputs": [],
      "source": [
        "from haystack import Pipeline\n",
        "from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever\n",
        "from haystack_integrations.components.embedders.fastembed import FastembedTextEmbedder\n",
        "from haystack.components.generators.chat import HuggingFaceAPIChatGenerator\n",
        "from haystack.components.builders import ChatPromptBuilder\n",
        "from haystack.dataclasses import ChatMessage\n",
        "\n",
        "from pprint import pprint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Enter your Hugging Face Token\n",
        "# this is needed to use Zephyr, calling the free Hugging Face Inference API\n",
        "\n",
        "from getpass import getpass\n",
        "import os\n",
        "\n",
        "os.environ[\"HF_API_TOKEN\"] = getpass(\"Enter your Hugging Face Token: https://huggingface.co/settings/tokens \")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "94cUO0isUE7u"
      },
      "outputs": [],
      "source": [
        "generator = HuggingFaceAPIChatGenerator(api_type=\"serverless_inference_api\",\n",
        "                              api_params={\"model\": \"Qwen/Qwen2.5-7B-Instruct\",\n",
        "                                          \"provider\": \"together\"},\n",
        "                                    generation_kwargs={\"max_tokens\":500})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f4w6WMDsUZy1"
      },
      "outputs": [],
      "source": [
        "# define the template\n",
        "\n",
        "template = [ChatMessage.from_user(\"\"\"\n",
        "Using only the information contained in these documents return a brief answer (max 50 words).\n",
        "If the answer cannot be inferred from the documents, respond \\\"I don't know\\\".\n",
        "Documents:\n",
        "{% for doc in documents %}\n",
        "    {{ doc.content }}\n",
        "{% endfor %}\n",
        "Question: {{question}}\n",
        "Answer:\n",
        "\"\"\")]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uBIq_9Zz5iC5"
      },
      "outputs": [],
      "source": [
        "query_pipeline = Pipeline()\n",
        "# FastembedTextEmbedder is used to embed the query\n",
        "query_pipeline.add_component(\"text_embedder\", FastembedTextEmbedder(model=\"BAAI/bge-small-en-v1.5\", parallel = 0, prefix=\"query:\"))\n",
        "query_pipeline.add_component(\"retriever\", QdrantEmbeddingRetriever(document_store=document_store))\n",
        "query_pipeline.add_component(\"prompt_builder\", ChatPromptBuilder(template=template))\n",
        "query_pipeline.add_component(\"generator\", generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wry8LFP4pAm_"
      },
      "outputs": [],
      "source": [
        "# connect the components\n",
        "query_pipeline.connect(\"text_embedder.embedding\", \"retriever.query_embedding\")\n",
        "query_pipeline.connect(\"retriever.documents\", \"prompt_builder.documents\")\n",
        "query_pipeline.connect(\"prompt_builder\", \"generator\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sQnk_qCW890T"
      },
      "source": [
        "## Try the pipeline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "NpUwlxIj6O0R"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Calculating embeddings: 100%|██████████| 1/1 [00:00<00:00, 24.62it/s]\n"
          ]
        }
      ],
      "source": [
        "question = \"Who is Dave Grohl?\"\n",
        "\n",
        "results = query_pipeline.run(\n",
        "    {   \"text_embedder\": {\"text\": question},\n",
        "        \"prompt_builder\": {\"question\": question},\n",
        "    }\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rcFsJVoqQ4zQ"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(' Dave Grohl is the founder and lead vocalist of the American rock band Foo '\n",
            " 'Fighters, which he formed in 1994 after the breakup of Nirvana, in which he '\n",
            " 'was the drummer.')\n"
          ]
        }
      ],
      "source": [
        "for d in results['generator']['replies']:\n",
        "  pprint(d.text)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "authorship_tag": "ABX9TyOUHSz9GGV1Pg5thVug4r0Z",
      "gpuType": "T4",
      "include_colab_link": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
