Haystack Streaming Text Generation

In this post, we're going to give our sample Retrieval Augmented Generation (RAG) Pipeline a bit of a makeover. Specifically, we’ll tweak it to “stream” results from multiple nodes, letting us show off what we’ve got as soon as it rolls in—no more waiting for everything to pile up at the end.

Right now, our RAG Pipeline has one big flaw: it waits until the entire result is ready before it lets the user see anything. Talk about suspense! Not the best user experience, eh? It's way better to start showing results the moment we have something. Not only does the app feel snappier, but the user gets to start reading while the rest is still on its way. Magic!

As it stands, our pipeline first grabs all the documents that match the query, then the Large Language Model (LLM) uses them to craft a response. The hitch? Haystack pipelines are synchronous, so we don’t get a peek at the retrieved documents until the LLM finishes its grand finale. But for document retrieval, the user might actually want to start digging into the docs while the LLM is still generating its response. Multitasking at its finest!

Ideally, we want the node fetching the documents to hand them over right away before passing them off to the LLM. And wouldn’t it be great if the LLM could stream its response a few words at a time, so we can watch it unfold in real-time?

The good news is, we can make this happen using a combo of custom Haystack components and Haystack’s streaming callbacks. The catch? It’s a bit more complicated. I have it on good authority (from a deepset employee on Haystack’s own Discord) that Haystack might add asynchronous pipelines in the future, but for now, custom components and streaming callbacks are the best tools in the toolbox.

Haystack’s Streaming Callback

To make streaming work correctly with our RAG pipeline, we’ll need to make changes to all three of our Python files. As a reminder, in our last post we broke our sample code into three files: the generator models (generator_model.py), the document processor, and the RAG pipeline itself.

The first change we need to make is to the Generator Models (i.e. generator_model.py). We need to enable them to accept a streaming callback function that Haystack can use. Fortunately, Haystack’s Hugging Face components—whether you’re using the local model or the API—already support passing in a streaming callback function. That’s what we’ll be using under the hood.
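For reference, this is roughly what passing a callback straight into one of Haystack’s Hugging Face generators looks like. Treat it as an illustrative sketch: the model name and API type here are just placeholders, not settings from this post.

from haystack.components.generators import HuggingFaceAPIGenerator

# A callback that prints each chunk of generated text as it arrives.
def print_chunk(chunk):
    print(chunk.content, end="", flush=True)

llm = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "google/gemma-1.1-2b-it"},
    streaming_callback=print_chunk,
)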

However, we want our pipeline to handle this dynamically, regardless of whether the specific Haystack Generator in use supports streaming callbacks. If the generator doesn’t support them, it should simply default to synchronous behavior and return the result at the end of the pipeline.

To make this kind of flexible behavior possible, we’ll adjust our generator classes so that they always pass a default streaming callback function for Hugging Face models. If the user doesn’t request streaming, the callback will just be ignored. Here’s how we can modify the Abstract Hugging Face Generator class (which underlies both the local and API-based generators for Hugging Face models) to achieve that:

from abc import ABC
from typing import Callable, Optional

import huggingface_hub as hf_hub
from haystack.dataclasses import StreamingChunk


class HuggingFaceModel(GeneratorModel, ABC):
    def __init__(self,
                 model_name: str = 'google/gemma-1.1-2b-it',
                 max_new_tokens: int = 500,
                 temperature: float = 0.6,
                 password: Optional[str] = None,
                 streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
                 verbose: bool = False) -> None:
        super().__init__(verbose)

        self._max_new_tokens: int = max_new_tokens
        self._temperature: float = temperature
        self._model_name: str = model_name
        self._streaming_callback: Optional[Callable[[StreamingChunk], None]] = streaming_callback

        if password is not None:
            hf_hub.login(password, add_to_git_credential=False)

    @property
    def streaming_callback(self) -> Optional[Callable[[StreamingChunk], None]]:
        return self._streaming_callback

    @streaming_callback.setter
    def streaming_callback(self, value: Optional[Callable[[StreamingChunk], None]]) -> None:
        self._streaming_callback = value

    def generate(self, prompt: str) -> str:
        # self._model is the underlying Haystack generator component, created by the subclasses.
        return self._model.run(prompt)

    def _default_streaming_callback_func(self, chunk: StreamingChunk) -> None:
        # Always handed to the Hugging Face component as its streaming callback.
        # If no real callback has been assigned, this is a no-op; otherwise it
        # forwards each chunk to the user-supplied callback.
        if self._streaming_callback is not None:
            self._streaming_callback(chunk)

So, we added a parameter for a streaming callback to the initializer. Unfortunately, that alone isn’t very useful with the way our RAG pipeline class currently works. What we really need is a way to pass in the model without specifying a streaming callback and then let the RAG pipeline add one later. Haystack’s Hugging Face components don’t support swapping in a callback like that, so we’re going to cheat: we always give the Hugging Face component our “_default_streaming_callback_func”. That function checks whether we’ve been given an actual streaming callback. If we haven’t, it does nothing; if we have, it calls it.

This approach magically allows us to create a property for all Hugging Face classes that we can just assign a streaming callback to. In fact, we can change which function is the streaming callback function on the fly now! Pretty neat, eh?
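For instance, once we have a model instance in hand (any of our Hugging Face wrappers will do), attaching a callback is now a one-liner. This is just an illustrative sketch, not code from the post:

# 'model' is an instance of one of our Hugging Face wrapper classes.
# Each StreamingChunk's text is printed the moment it arrives.
model.streaming_callback = lambda chunk: print(chunk.content, end="", flush=True)

# Swapping in a different callback on the fly works too, e.g. collecting chunks instead:
collected = []
model.streaming_callback = lambda chunk: collected.append(chunk.content)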

There are also some small changes to our HuggingFaceAPIModel and HuggingFaceLocalModel classes to let them pass the streaming callback function into their initialization as well. I won’t go over those here since, honestly, they aren’t needed for our purposes.
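That said, to give a sense of what those subclasses do with it, here’s a rough idea of how the local-model wrapper might hand Haystack our default callback. The constructor arguments are based on Haystack’s HuggingFaceLocalGenerator, but the subclass body itself is an illustration rather than the post’s actual code:

from haystack.components.generators import HuggingFaceLocalGenerator

class HuggingFaceLocalModel(HuggingFaceModel):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Always hand Haystack our default callback; it only forwards chunks if
        # a real callback has been assigned via the streaming_callback property.
        self._model = HuggingFaceLocalGenerator(
            model=self._model_name,
            generation_kwargs={
                "max_new_tokens": self._max_new_tokens,
                "temperature": self._temperature,
            },
            streaming_callback=self._default_streaming_callback_func,
        )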

The upshot of these changes is that all Hugging Face models we create (if we do so via our generator model classes) will now have a way to specify a streaming callback. The RAG pipeline will utilize this new feature if we tell it to.

I did not add streaming callback functionality to our Google Gemini model class wrapper because the version of Haystack I’m on doesn’t support it there. Google Gemini is so blazing fast that it won’t really matter much anyway, but that’s why only our Hugging Face model wrappers got the streaming callback treatment.

I made some adjustments to the Document Processor file as well, though to be honest those were more code hygiene issues I’ve been meaning to get to. So, I’ll not go over those changes either as they aren’t truly relevant to our current purposes.

Building a “Streaming” Haystack Component

Haystack natively supports streaming callbacks for its Hugging Face generator components. But what about the document retriever component? Haystack does not have any sort of ‘streaming callback’ on those components. So what if we want to show off the retrieved documents while we wait for the model to generate its own response?

One way to handle that is to simply write our own document retrieving component that wraps the Haystack PgvectorEmbeddingRetriever component but streams out the documents via print statements to the console before returning the documents to the next node in our pipeline. Here is what I built:

from typing import Any, Dict, List

from haystack import Document, component
from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever


@component
class StreamingRetriever:
    def __init__(self, retriever: PgvectorEmbeddingRetriever):
        self.retriever = retriever

    @component.output_types(documents=List[Document])
    def run(self, query_embedding: List[float]) -> Dict[str, Any]:
        # Run the wrapped retriever, then print the documents immediately so the
        # user can start reading them while the LLM is still generating.
        documents = self.retriever.run(query_embedding=query_embedding)['documents']
        print_documents(documents)
        # Hand the documents along to the next component in the pipeline
        return {"documents": documents}

The print_documents function is taken straight from our old generate_response method, so it may look familiar:

def print_documents(documents: List[Document]) -> None:
    for i, doc in enumerate(documents, 1):
        print(f"Document {i}:")
        print(f"Score: {doc.score}")
        if hasattr(doc, 'meta') and doc.meta:
            if 'title' in doc.meta:
                print(f"Title: {doc.meta['title']}")
            if 'section_num' in doc.meta:
                print(f"Section: {doc.meta['section_num']}")
        print(f"Content: {doc.content}")
        print("-" * 50)

We then need to adjust our pipeline to use this new component – but we only want to do that IF the user asks for streaming. Here is the adjusted code to build our RAG pipeline:

    def _create_rag_pipeline(self) -> None:
        self._setup_embedder()
        self._setup_generator()
        prompt_builder: PromptBuilder = PromptBuilder(template=self._prompt_template)

        rag_pipeline: Pipeline = Pipeline()

        # Add the query embedder and the prompt builder
        rag_pipeline.add_component("query_embedder", self._sentence_embedder)
        rag_pipeline.add_component("prompt_builder", prompt_builder)

        # If streaming is enabled, use the StreamingRetriever
        if self._can_stream():
            streaming_retriever: StreamingRetriever = StreamingRetriever(
                retriever=PgvectorEmbeddingRetriever(document_store=self._document_store, top_k=5))
            rag_pipeline.add_component("retriever", streaming_retriever)
        else:
            # Use the standard retriever if not streaming
            rag_pipeline.add_component("retriever",
                                       PgvectorEmbeddingRetriever(document_store=self._document_store, top_k=5))

        # Add the LLM component
        if isinstance(self._generator_model, gen.GeneratorModel):
            rag_pipeline.add_component("llm", self._generator_model.generator_component)
        else:
            rag_pipeline.add_component("llm", self._generator_model)

        if not self._can_stream():
            # Add the merger only when streaming is disabled
            rag_pipeline.add_component("merger", MergeResults())
            rag_pipeline.connect("retriever.documents", "merger.documents")
            rag_pipeline.connect("llm.replies", "merger.replies")

        # Connect the components for both streaming and non-streaming scenarios
        rag_pipeline.connect("query_embedder.embedding", "retriever.query_embedding")
        rag_pipeline.connect("retriever.documents", "prompt_builder.documents")
        rag_pipeline.connect("prompt_builder", "llm")

        # Set the pipeline instance
        self._rag_pipeline = rag_pipeline

Note the various places where I check whether we’re asking for streaming via the “_can_stream()” method. If we are streaming, we use the StreamingRetriever in place of the PgvectorEmbeddingRetriever. And if we are not streaming, we add back the final custom merger node to bring the documents and the LLM-generated results together for output (a rough sketch of that merger component is below).
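As a refresher, MergeResults is the small custom component from the earlier post. It might look roughly like this; the exact output name and shape here are my assumption, not the post’s original code:

from typing import Any, Dict, List

from haystack import Document, component


@component
class MergeResults:
    @component.output_types(merged_results=Dict[str, Any])
    def run(self, documents: List[Document], replies: List[str]) -> Dict[str, Any]:
        # Bundle the retrieved documents and the LLM's replies into one output
        # so the caller gets both at the end of a non-streaming run.
        return {"merged_results": {"documents": documents, "replies": replies}}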

Here is the code for _can_stream():

    def _can_stream(self) -> bool:
        return (self._use_streaming
                and self._generator_model is not None
                and isinstance(self._generator_model, gen.GeneratorModel)
                and hasattr(self._generator_model, 'streaming_callback'))

This probably isn’t the best way to do this. For one thing, it turns off streaming the documents to the user for no reason other than that there’s no LLM streaming callback available. But it does keep things simple for our toy example. Basically, I set up the streaming version of the pipeline only if the following criteria are met:

  1. The user specified they want it via the use_streaming parameter
  2. We actually have a generator model
  3. The generator model is of type GeneratorModel (i.e., we’re not passing in a raw Haystack generator directly. We supported that in the past, but currently the code only works with our GeneratorModel class or its subclasses.)
  4. The model has a ‘streaming_callback’ attribute.

I also changed the generate_response method to NOT show the final response if it is to be streamed. I’ll not show off that code here as it isn’t that interesting.
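If you’re curious, the gist is something like the sketch below. The generate_response name, _can_stream, and print_documents come from earlier; the pipeline input keys and the merger output handling are assumptions for illustration:

    def generate_response(self, query: str) -> None:
        if self._can_stream():
            # Stream the LLM's reply to the console as it is generated; the
            # StreamingRetriever already printed the documents earlier in the run.
            self._generator_model.streaming_callback = lambda chunk: print(chunk.content, end="", flush=True)

        results = self._rag_pipeline.run({
            "query_embedder": {"text": query},
            "prompt_builder": {"query": query},
        })

        if not self._can_stream():
            # Nothing was streamed, so show the merged documents and reply now.
            merged = results["merger"]["merged_results"]
            print_documents(merged["documents"])
            print(merged["replies"][0])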

Results

The end results are exactly what I hoped for! The documents display to the user as soon as they are retrieved from the PostgreSQL datastore and then we can watch the LLM generate its response one word at a time. This is a much-improved user experience!
