Setting up a hallucination checking agent with LangGraph

What this tutorial covers:

  • Extracting claims from AI-generated text (or text from any other potentially unreliable source)
  • Using Exa to retrieve relevant sources that may support or refute these claims
  • Using an LLM to analyze the claims against the retrieved sources and assign a confidence score

Guide

One of the significant concerns with AI language models is their tendency to produce hallucinations—statements that sound plausible but are not grounded in factual data. This can erode trust in AI systems, especially when accurate information is critical.

Exa is a powerful tool for mitigating this. In this tutorial, we demonstrate how to use Exa to cross-reference AI-generated content with real-world sources, verify the authenticity of claims, and identify hallucinations.

Set Up

Let's kick things off by importing the necessary libraries and setting up our LLM and Exa search retriever.

import os
import re
import json
from typing import Dict, Any, List, Annotated
from pydantic import BaseModel
from langchain_core.tools import StructuredTool
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_exa import ExaSearchRetriever
from langchain_core.runnables import RunnableLambda
from langchain_core.prompts import PromptTemplate
from langchain_anthropic import ChatAnthropic

# Check for API keys
assert os.getenv("EXA_API_KEY"), "Please set the EXA_API_KEY environment variable"
assert os.getenv("ANTHROPIC_API_KEY"), "Please set the ANTHROPIC_API_KEY environment variable"

# Set up the LLM (ChatAnthropic)
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620", temperature=0)
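
If you keep your API keys in a local .env file rather than exporting them in your shell, a small optional snippet using the python-dotenv package (not otherwise used in this tutorial) can load them before the asserts run:

# Optional: load EXA_API_KEY and ANTHROPIC_API_KEY from a .env file
# (requires `pip install python-dotenv`); place this above the asserts.
from dotenv import load_dotenv
load_dotenv()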

Claim and Source Identification with Exa

Now, let's define a function to extract factual claims from the input text using our LLM. If the LLM's response can't be parsed as JSON, we fall back to a simple regex-based extraction.

def extract_claims(text: str) -> List[str]:
    """Extract factual claims from the text using an LLM."""
    system_message = SystemMessage(content="""
    You are an expert at extracting claims from text.
    Your task is to identify and list all claims present, true or false,
    in the given text. Each claim should be a single, verifiable statement.
    Consider various forms of claims, including assertions, statistics, and
    quotes. Do not skip any claims, even if they seem obvious. Do not include in the list 'The text contains a claim that needs to be checked for hallucinations' - this is not a claim.
    Present the claims as a JSON array of strings, and do not include any additional text.
    """)

    human_message = HumanMessage(content=f"Extract factual claims from this text: {text}")
    response = llm.invoke([system_message, human_message])

    try:
        claims = json.loads(response.content)
        if not isinstance(claims, list):
            raise ValueError("Response is not a list")
    except (json.JSONDecodeError, ValueError):
        # Fallback to regex extraction if LLM response is not valid JSON
        claims = extract_claims_regex(text)
    
    return claims

def extract_claims_regex(text: str) -> List[str]:
    """Fallback function to extract claims using regex."""
    # Each match already ends with sentence punctuation, so nothing extra is appended
    pattern = r'([A-Z][^.!?]*?[.!?])'
    matches = re.findall(pattern, text)
    return [match.strip() for match in matches]
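
As a quick sanity check, you can call extract_claims directly on a short snippet. The output below is illustrative; the exact claims and their wording depend on the model:

# Illustrative only: the exact claims returned depend on the model
sample = "The Eiffel Tower was originally constructed as a giant sundial in 1822."
print(extract_claims(sample))
# e.g. ['The Eiffel Tower was originally constructed as a giant sundial',
#       'The Eiffel Tower was constructed in 1822']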

We'll use Exa to search for relevant sources that might support or refute each claim.

def exa_search(query: str) -> List[str]:
    """Function to retrieve usable documents for AI assistant."""
    search = ExaSearchRetriever(k=5, text=True, use_autoprompt=False)

    print("Query: ", query)

    document_prompt = PromptTemplate.from_template(
        """
        <source>
            <url>{url}</url>
            <text>{text}</text>
        </source>
        """
    )

    parse_info = RunnableLambda(
        lambda document: {
            "url": document.metadata["url"],
            "text": document.page_content or "No text available",
        }
    )

    document_chain = (parse_info | document_prompt)
    search_chain = search | document_chain.map()
    documents = search_chain.invoke(query+".\n Here is a web page to help verify this claim:")

    print("Documents: ", documents)
    
    return [str(doc) for doc in documents]
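
Each result comes back as a small XML-style <source> block, which keeps the URL attached to its text when we later hand everything to the fact-checking prompt. A quick standalone call looks something like this (the retrieved pages will vary from run to run):

# Standalone search for a single claim; results vary between runs
sources = exa_search("The Eiffel Tower is located in Paris")
print(len(sources))      # up to 5 formatted <source> blocks
print(sources[0][:200])  # peek at the start of the first source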

Claim Verification

We'll use an LLM to assess whether the combined sources support or refute the claim. We're bundling all sources together to minimize LLM calls.

def verify_claim(claim: str, sources: List[str]) -> Dict[str, Any]:
    """Verify a single claim using combined Exa search sources."""
    if not sources:
        # If no sources are returned, default to insufficient information
        return {
            "claim": claim,
            "assessment": "Insufficient information",
            "confidence_score": 0.5,
            "supporting_sources": [],
            "refuting_sources": []
        }
    
    # Combine the sources into one text
    combined_sources = "\n\n".join(sources)
    
    system_message = SystemMessage(content="""
    You are an expert fact-checker.
    Given a claim and a set of sources, determine whether the claim is supported, refuted, or if there is insufficient information in the sources to make a determination.
    For your analysis, consider all the sources collectively.
    Provide your answer as a JSON object with the following structure:
    {
        "claim": "...",
        "assessment": "supported" or "refuted" or "Insufficient information",
        "confidence_score": a number between 0 and 1 (1 means fully confident the claim is true, 0 means fully confident the claim is false),
        "supporting_sources": [list of sources that support the claim],
        "refuting_sources": [list of sources that refute the claim]
    }
    Do not include any additional text.
    """)
    
    human_message = HumanMessage(content=f"""
    Claim: "{claim}"
    
    Sources:
    {combined_sources}
    
    Based on the above sources, assess the claim.
    """)
    
    response = llm.invoke([system_message, human_message])
    
    try:
        result = json.loads(response.content)
        if not isinstance(result, dict):
            raise ValueError("Response is not a JSON object")
    except (json.JSONDecodeError, ValueError):
        # If parsing fails, default to insufficient information
        result = {
            "claim": claim,
            "assessment": "Insufficient information",
            "confidence_score": 0.5,
            "supporting_sources": [],
            "refuting_sources": []
        }
    
    return result
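
Together, exa_search and verify_claim are enough to check a single claim end to end; the sketch below does exactly what the workflow function in the next step repeats for every extracted claim (the assessment and score depend on what Exa retrieves):

# Verify one claim end to end; the result depends on the retrieved sources
claim = "The Eiffel Tower was constructed in 1822"
result = verify_claim(claim, exa_search(claim))
print(result["assessment"], result["confidence_score"])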

Let's wrap it all up with our hallucination check workflow tool.

def hallucination_check(text: str) -> Dict[str, Any]:
    """Check a given text for hallucinations using Exa search."""
    claims = extract_claims(text)
    claim_verifications = []

    for claim in claims:
        sources = exa_search(claim)
        verification_result = verify_claim(claim, sources)
        claim_verifications.append(verification_result)

    return {
        "claims": claim_verifications
    }

def hallucination_check_tool(text: str) -> Dict[str, Any]:
    """Assess the given text for hallucinations using Exa search."""
    return hallucination_check(text)

structured_tool = StructuredTool.from_function(
    func=hallucination_check_tool,
    name="hallucination_check",
    description="Assess the given text for hallucinations using Exa search."
)
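
Before wiring the tool into a graph, you can exercise it on its own. Because it takes a single string argument, you can pass the text directly to invoke, just as the graph node does later:

# Try the tool standalone; the output format mirrors hallucination_check
result = structured_tool.invoke(
    "The Eiffel Tower was originally constructed as a giant sundial in 1822."
)
for claim_info in result["claims"]:
    print(claim_info["claim"], "->", claim_info["assessment"])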

Run the Workflow

Time to put it all together. We'll use StateGraph to set up a simple workflow and run a quick example about the Eiffel Tower.

class State(BaseModel):
    messages: Annotated[List, add_messages]
    analysis_result: Dict[str, Any] = {}

def call_model(state: State):
    # Simulate the assistant deciding to call the tool; with the add_messages
    # reducer, returning only the new message appends it to the conversation.
    return {"messages": [AIMessage(content="Use hallucination_check tool", additional_kwargs={"tool_calls": [{"type": "function", "function": {"name": "hallucination_check"}}]})]}

def run_tool(state: State):
    # Run the hallucination check on the most recent human message
    text_to_check = next((m.content for m in reversed(state.messages) if isinstance(m, HumanMessage)), "")
    tool_output = structured_tool.invoke(text_to_check)
    return {"messages": [AIMessage(content=str(tool_output))], "analysis_result": tool_output}

def use_analysis(state: State) -> str:
    return "tools"

workflow = StateGraph(State)

workflow.add_node("agent", call_model)
workflow.add_node("tools", run_tool)
workflow.add_node("process_result", lambda x: x)

workflow.set_entry_point("agent")
workflow.add_conditional_edges("agent", use_analysis, {
    "tools": "tools"
})
workflow.add_edge("tools", "process_result")
workflow.add_edge("process_result", END)

graph = workflow.compile()

# Example usage
if __name__ == "__main__":
    initial_state = State(messages=[
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Check this text for hallucinations: The Eiffel Tower, an iconic iron lattice structure located in Paris, was originally constructed as a giant sundial in 1822.")
    ])

    final_state = graph.invoke(initial_state)

    print("Workflow executed successfully")
    print("Final state:")
    print("Messages:")
    for message in final_state["messages"]:
        print(f"{message.__class__.__name__}: {message.content[:100]}...")  # Print first 100 characters

    print("\nAnalysis Result:")
    for claim_info in final_state["analysis_result"]["claims"]:
        print(f"Claim: {claim_info['claim']}")
        print(f"Assessment: {claim_info['assessment']}")
        print(f"Confidence Score: {claim_info['confidence_score']}")
        print("Supporting Sources:")
        for source in claim_info['supporting_sources']:
            print(f"- {source[:100]}...")  # Print first 100 characters
        print("Refuting Sources:")
        for source in claim_info['refuting_sources']:
            print(f"- {source[:100]}...")
        print()

Sample output:

Workflow executed successfully
Final state:
Messages:
SystemMessage: You are a helpful assistant....
HumanMessage: Check this text for hallucinations: The Eiffel Tower, an iconic iron lattice structure located in Pa...
AIMessage: Use hallucination_check tool...
AIMessage: {'claims': [{'claim': 'The Eiffel Tower is an iconic iron lattice structure', 'assessment': 'support...

Analysis Result:
Claim: The Eiffel Tower is an iconic iron lattice structure
Assessment: supported
Confidence Score: 1
Supporting Sources:
- https://www.toureiffel.paris/en/news/130-years/what-eiffel-tower-made...
- https://thechalkface.net/resources/melting_the_eiffel_tower.pdf...
- https://datagenetics.com/blog/april22016/index.html...
- https://engineering.purdue.edu/MSE/aboutus/gotmaterials/Buildings/patel.html...
- https://www.toureiffel.paris/en/news/130-years/how-long-can-tower-last...
Refuting Sources:

Claim: The Eiffel Tower is located in Paris
Assessment: supported
Confidence Score: 1
Supporting Sources:
- https://hoaxes.org/weblog/comments/is_the_eiffel_tower_copyrighted...
- https://www.toureiffel.paris/en...
- http://www.eiffeltowerguide.com/...
- https://www.toureiffel.paris/en/the-monument...
Refuting Sources:

Claim: The Eiffel Tower was originally constructed as a giant sundial
Assessment: refuted
Confidence Score: 0.05
Supporting Sources:
Refuting Sources:
- https://www.whycenter.com/why-was-the-eiffel-tower-built/...
- https://www.sciencekids.co.nz/sciencefacts/engineering/eiffeltower.html...
- https://corrosion-doctors.org/Landmarks/eiffel-history.htm...

Claim: The Eiffel Tower was constructed in 1822
Assessment: refuted
Confidence Score: 0
Supporting Sources:
Refuting Sources:
- https://www.eiffeltowerfacts.org/eiffel-tower-history/...
- https://www.whycenter.com/why-was-the-eiffel-tower-built/...
- https://www.sciencekids.co.nz/sciencefacts/engineering/eiffeltower.html...

And there you have it! This workflow shows how Exa and an LLM can work together to flag which claims are hallucinations and assess the validity of information.

Exa can be integrated into your applications as a powerful tool for addressing one of the primary concerns with AI language models.

Full Code

import os
import re
import json
from typing import Dict, Any, List, Annotated
from pydantic import BaseModel
from langchain_core.tools import StructuredTool
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_exa import ExaSearchRetriever
from langchain_core.runnables import RunnableLambda
from langchain_core.prompts import PromptTemplate
from langchain_anthropic import ChatAnthropic

# Check for API keys
assert os.getenv("EXA_API_KEY"), "Please set the EXA_API_KEY environment variable"
assert os.getenv("ANTHROPIC_API_KEY"), "Please set the ANTHROPIC_API_KEY environment variable"

# Set up the LLM (ChatAnthropic)
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620", temperature=0)

def exa_search(query: str) -> List[str]:
    """Function to retrieve usable documents for AI assistant."""
    search = ExaSearchRetriever(k=5, text=True, use_autoprompt=False)

    document_prompt = PromptTemplate.from_template(
        """
        <source>
            <url>{url}</url>
            <text>{text}</text>
        </source>
        """
    )

    parse_info = RunnableLambda(
        lambda document: {
            "url": document.metadata["url"],
            "text": document.page_content or "No text available",

        }
    )

    document_chain = (parse_info | document_prompt)
    search_chain = search | document_chain.map()
    documents = search_chain.invoke(query+".\n Here is a web page to help verify this claim:")
    
    return [str(doc) for doc in documents]

def extract_claims(text: str) -> List[str]:
    """Extract factual claims from the text using an LLM."""
    system_message = SystemMessage(content="""
    You are an expert at extracting claims from text.
    Your task is to identify and list all claims present, true or false,
    in the given text. Each claim should be a single, verifiable statement.
    Consider various forms of claims, including assertions, statistics, and
    quotes. Do not skip any claims, even if they seem obvious. Do not include in the list 'The text contains a claim that needs to be checked for hallucinations' - this is not a claim.
    Present the claims as a JSON array of strings, and do not include any additional text.
    """)

    human_message = HumanMessage(content=f"Extract factual claims from this text: {text}")
    response = llm.invoke([system_message, human_message])

    try:
        claims = json.loads(response.content)
        if not isinstance(claims, list):
            raise ValueError("Response is not a list")
    except (json.JSONDecodeError, ValueError):
        # Fallback to regex extraction if LLM response is not valid JSON
        claims = extract_claims_regex(text)
    
    return claims

def extract_claims_regex(text: str) -> List[str]:
    """Fallback function to extract claims using regex."""
    # Each match already ends with sentence punctuation, so nothing extra is appended
    pattern = r'([A-Z][^.!?]*?[.!?])'
    matches = re.findall(pattern, text)
    return [match.strip() for match in matches]

def verify_claim(claim: str, sources: List[str]) -> Dict[str, Any]:
    """Verify a single claim using combined Exa search sources."""
    if not sources:
        # If no sources are returned, default to insufficient information
        return {
            "claim": claim,
            "assessment": "Insufficient information",
            "confidence_score": 0.5,
            "supporting_sources": [],
            "refuting_sources": []
        }
    
    # Combine the sources into one text
    combined_sources = "\n\n".join(sources)
    
    system_message = SystemMessage(content="""
    You are an expert fact-checker.
    Given a claim and a set of sources, determine whether the claim is supported, refuted, or if there is insufficient information in the sources to make a determination.
    For your analysis, consider all the sources collectively.
    Provide your answer as a JSON object with the following structure:
    {
        "claim": "...",
        "assessment": "supported" or "refuted" or "Insufficient information",
        "confidence_score": a number between 0 and 1 (1 means fully confident the claim is true, 0 means fully confident the claim is false),
        "supporting_sources": [list of sources that support the claim],
        "refuting_sources": [list of sources that refute the claim]
    }
    Do not include any additional text.
    """)
    
    human_message = HumanMessage(content=f"""
    Claim: "{claim}"
    
    Sources:
    {combined_sources}
    
    Based on the above sources, assess the claim.
    """)
    
    response = llm.invoke([system_message, human_message])
    
    try:
        result = json.loads(response.content)
        if not isinstance(result, dict):
            raise ValueError("Response is not a JSON object")
    except (json.JSONDecodeError, ValueError):
        # If parsing fails, default to insufficient information
        result = {
            "claim": claim,
            "assessment": "Insufficient information",
            "confidence_score": 0.5,
            "supporting_sources": [],
            "refuting_sources": []
        }
    
    return result

def hallucination_check(text: str) -> Dict[str, Any]:
    """Check a given text for hallucinations using Exa search."""
    claims = extract_claims(text)
    claim_verifications = []

    for claim in claims:
        sources = exa_search(claim)
        verification_result = verify_claim(claim, sources)
        claim_verifications.append(verification_result)

    return {
        "claims": claim_verifications
    }

def hallucination_check_tool(text: str) -> Dict[str, Any]:
    """Assess the given text for hallucinations using Exa search."""
    return hallucination_check(text)

structured_tool = StructuredTool.from_function(
    func=hallucination_check_tool,
    name="hallucination_check",
    description="Assess the given text for hallucinations using Exa search."
)

class State(BaseModel):
    messages: Annotated[List, add_messages]
    analysis_result: Dict[str, Any] = {}

def call_model(state: State):
    # Simulate the assistant deciding to call the tool; with the add_messages
    # reducer, returning only the new message appends it to the conversation.
    return {"messages": [AIMessage(content="Use hallucination_check tool", additional_kwargs={"tool_calls": [{"type": "function", "function": {"name": "hallucination_check"}}]})]}

def run_tool(state: State):
    # Run the hallucination check on the most recent human message
    text_to_check = next((m.content for m in reversed(state.messages) if isinstance(m, HumanMessage)), "")
    tool_output = structured_tool.invoke(text_to_check)
    return {"messages": [AIMessage(content=str(tool_output))], "analysis_result": tool_output}

def use_analysis(state: State) -> str:
    return "tools"

workflow = StateGraph(State)

workflow.add_node("agent", call_model)
workflow.add_node("tools", run_tool)
workflow.add_node("process_result", lambda x: x)

workflow.set_entry_point("agent")
workflow.add_conditional_edges("agent", use_analysis, {
    "tools": "tools"
})
workflow.add_edge("tools", "process_result")
workflow.add_edge("process_result", END)

graph = workflow.compile()

# Example usage
if __name__ == "__main__":
    initial_state = State(messages=[
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Check this text for hallucinations: The Eiffel Tower, an iconic iron lattice structure located in Paris, was originally constructed as a giant sundial in 1822.")
    ])

    final_state = graph.invoke(initial_state)

    print("Workflow executed successfully")
    print("Final state:")
    print("Messages:")
    for message in final_state["messages"]:
        print(f"{message.__class__.__name__}: {message.content[:100]}...")  # Print first 100 characters

    print("\nAnalysis Result:")
    for claim_info in final_state["analysis_result"]["claims"]:
        print(f"Claim: {claim_info['claim']}")
        print(f"Assessment: {claim_info['assessment']}")
        print(f"Confidence Score: {claim_info['confidence_score']}")
        print("Supporting Sources:")
        for source in claim_info['supporting_sources']:
            print(f"- {source[:100]}...")  # Print first 100 characters
        print("Refuting Sources:")
        for source in claim_info['refuting_sources']:
            print(f"- {source[:100]}...")
        print()