DSPy Assertions: Computational Constraints for Self-Refining Language Model Pipelines
by Shangyin Tan
TLDR: We enable developers to define constraints on LM outputs that help LMs self-refine and help the compiler choose better demonstrations.
Chaining language model (LM) calls is essential to building powerful LM applications. DSPy is a fantastic framework for creating such pipelines declaratively, with automatic prompt tuning through compilation; this short blog summarizes how DSPy works in a nutshell. One emerging challenge is that LM outputs are not always satisfactory. This is a serious problem in LM pipelines: with multiple LM calls, a single misprediction can lead to a chain of errors or suboptimal results.
Let’s look at a concrete example of a multi-hop question answering pipeline:
```python
import dspy

class MultiHopQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=3)
        self.gen_query = dspy.ChainOfThought("context, question -> query")
        self.gen_answer = dspy.ChainOfThought("context, question -> answer")

    def forward(self, question):
        context = []
        for hop in range(2):
            query = self.gen_query(context=context, question=question).query
            context += self.retrieve(query).passages
        return self.gen_answer(context=context, question=question)
```
In this example, we have a Retrieve module that retrieves the top k passages from a corpus given a query, a ChainOfThought module that generates a query for the retriever, and another ChainOfThought module that generates an answer from the retrieved context and the original question. One suboptimal behavior we observe is that the query generator may produce a query that is too similar to the original question or to previous queries, leading to redundant retrieval. Although this is not a critical error that breaks the pipeline, it is still behavior we want to avoid.
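Before adding constraints, here is a minimal sketch of how such a pipeline is typically configured and run. The specific LM and retrieval model (dspy.OpenAI with gpt-3.5-turbo and a ColBERTv2 endpoint at a placeholder URL) and the sample question are assumptions for illustration, not part of the example above.

```python
import dspy

# Hypothetical setup: any LM and retrieval model supported by DSPy will do.
lm = dspy.OpenAI(model="gpt-3.5-turbo")
rm = dspy.ColBERTv2(url="http://your-colbertv2-endpoint")  # placeholder URL
dspy.settings.configure(lm=lm, rm=rm)

# Run the (uncompiled) pipeline on a sample question.
program = MultiHopQA()
prediction = program(question="Which award did the author of the cited novel win?")
print(prediction.answer)
```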
Now, we introduce a new weapon to the DSPy arsenal: LM Assertions. LM Assertions are constraints that we can define on LM outputs. For example, we can require that the query generated by the gen_query module not be too similar to the original question. As noted above, the behavior we want to avoid is not a critical error, so we can express it as a soft constraint using dspy.Suggest(constraint, message). Here is how we can define the constraint:
```python
class MultiHopQAWithAssertions(MultiHopQA):
    def forward(self, question):
        context, queries = [], [question]
        for hop in range(2):
            query = self.gen_query(context=context, question=question).query
            dspy.Suggest(is_query_distinct(query, queries),
                         f"Query should be distinct from {queries}")
            context += self.retrieve(query).passages
            queries.append(query)
        return self.gen_answer(context=context, question=question)
```
This simple one-liner, dspy.Suggest, tells the DSPy executor that the query generated by gen_query should be distinct from all previous queries. The is_query_distinct function simply computes the similarity between two strings, but the constraint can be any function that returns a boolean, including another DSPy module.
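The helper itself is not shown above; a minimal sketch of one possible implementation, using a word-overlap heuristic with an arbitrary threshold (both are assumptions for illustration, not the function used in the paper), could look like this:

```python
def is_query_distinct(query: str, previous_queries: list[str], threshold: float = 0.8) -> bool:
    """Return True if `query` is sufficiently different from every previous query.

    Uses Jaccard similarity over lowercased word sets; the 0.8 threshold is an
    arbitrary choice for illustration.
    """
    query_words = set(query.lower().split())
    for prev in previous_queries:
        prev_words = set(prev.lower().split())
        union = query_words | prev_words
        if not union:
            continue
        similarity = len(query_words & prev_words) / len(union)
        if similarity >= threshold:
            return False
    return True
```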
Here’s how the suggestion and self-refinement come into play: if the executor finds that the constraint is violated, it pauses the current execution, adds the failed constraint (the past output and an instruction to fix it) to the prompt, and backtracks to the module that produced the failing output.
Suppose the query generated by gen_query is too similar to the original question. In that case, the executor adds the following fields to the prompt and backtracks to the gen_query module:
Updated Prompt with Suggestion Feedback:

```text
Context: …
Question: …
Past_Query: {previous attempt w/ errors}
Instructions: Query should be distinct from {previous queries}
…
```
With the updated prompt, the gen_query module is more likely to generate a query that is distinct from the previous queries. The executor then resumes execution and continues the pipeline.
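To make the control flow concrete, here is a schematic sketch of the backtrack-and-retry loop in plain Python. This is not DSPy's actual implementation; the function name, the max_backtracks limit, and the extra keyword fields are assumptions for illustration.

```python
def call_with_suggestion(module, inputs, constraint, feedback, max_backtracks=2):
    """Schematic retry loop (not DSPy's internal code): call the module, check the
    suggestion, and on failure retry with the failed output and feedback added."""
    output = module(**inputs)
    for _ in range(max_backtracks):
        if constraint(output):
            break  # suggestion satisfied; resume the pipeline normally
        # Backtrack: re-invoke the failing module with the past attempt and
        # the fix-it instruction exposed as extra prompt fields.
        output = module(**inputs, past_output=output, instructions=feedback)
    return output
```

In actual DSPy programs this loop is handled by the assertion machinery itself; you only write the dspy.Suggest call inside forward.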
We have conducted some preliminary experiments to test the effectiveness of LM Assertions, focusing on intrinsic and extrinsic metrics. The intrinsic metric measures how adding LM Assertions and self-refinement affects the desired properties of the LM outputs, or simply, the fraction of suggestions passed. The extrinsic metrics measure how this improves downstream task performance: whether the gold passages are retrieved and whether the generated answer is correct. The experiments are conducted on the HotPotQA dataset.
We also compare ZeroShot and FewShot compilation, and how adding assertions/suggestions affects performance in each setting. Here are the results:
| Configuration | Suggestions Passed | Retrieval Score | Answer Correctness |
|---|---|---|---|
| ZeroShot – NoSuggest | 64.3% | 34.7% | 45.7% |
| ZeroShot – Suggest | 87.3% | 39.3% | 46.3% |
| FewShot – NoSuggest | 65.0% | 40.3% | 49.3% |
| FewShot – Suggest | 82.7% | 42.0% | 50.0% |
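For reference, the extrinsic metrics in the table can be approximated as follows. The normalization, the exact-match criterion, and the title-in-passage check are rough proxies chosen for illustration, not the exact evaluation code used in the paper.

```python
def _normalize(s: str) -> str:
    """Lowercase and collapse whitespace."""
    return " ".join(s.lower().strip().split())


def answer_correct(pred_answer: str, gold_answer: str) -> bool:
    """Answer correctness as normalized exact match (a rough proxy)."""
    return _normalize(pred_answer) == _normalize(gold_answer)


def gold_passages_retrieved(passages: list[str], gold_titles: list[str]) -> bool:
    """Retrieval score: True if every gold supporting title appears in the retrieved text."""
    text = _normalize(" ".join(passages))
    return all(_normalize(title) in text for title in gold_titles)
```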
For fancier stuff like expressing hard constraints with dspy.Assert, the formal semantics of LM Assertions, different error handlers, more sophisticated experiments, and more, please check out our paper.