Text-to-SQL with LangChain (1) - Chain

Series - Text-to-SQL with LangChain

In this post, we’ll dive into how to set up Text-to-SQL with LangChain, guided by the Q&A over SQL+CSV Quickstart.

  • We’ll be using LangChain version 0.2.0.

    $ pip list|grep langchain
    langchain                0.2.0
    langchain-community      0.2.0
    langchain-core           0.2.0
    langchain-openai         0.1.7
    langchain-text-splitters 0.2.0
    

First, install the necessary libraries.

pip install --upgrade --quiet langchain langchain-community langchain-openai

Next, set the environment variable OPENAI_API_KEY. Create a .openai file in your working directory containing your OpenAI API key, and run the following code.

import os

# Set API key as an environment variable
with open('.openai') as f:
    os.environ['OPENAI_API_KEY'] = f.read().strip()

For testing, we’ll use the Chinook database, a sample database that models a digital media store. It includes tables such as:

  1. Artist: Artist information
  2. Album: Album information
  3. Track: Track (song) information
  4. Customer: Customer information
  5. Invoice: Invoice information
  6. Employee: Employee information

To create the Chinook database, download this file and place it in your working directory. Then, run the following code.

import sqlite3
from contextlib import closing

# Database file and SQL script names
db_file = 'Chinook.db'
sql_file = 'Chinook_Sqlite.sql'

# Read the SQL script
with open(sql_file, 'r', encoding='utf-8') as file:
    sql_script = file.read()

# Open the connection, run the script, and close the connection afterwards
# (a bare "with sqlite3.connect(...)" commits but does not close the connection)
with closing(sqlite3.connect(db_file)) as conn:
    conn.executescript(sql_script)

    # Commit and save database
    conn.commit()

This will create a SQLite database named Chinook.db in your working directory.
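
As a quick sanity check, the tables created by the script can be listed with the standard sqlite3 module before involving LangChain at all. This is a minimal sketch; the helper name list_tables is hypothetical and not part of the original walkthrough.

```python
import os
import sqlite3


def list_tables(db_path: str) -> list[str]:
    """Return the names of the user tables in a SQLite file, sorted by name."""
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
            "ORDER BY name"
        ).fetchall()
    finally:
        conn.close()
    return [name for (name,) in rows]


if __name__ == "__main__":
    # Only inspect the file if it exists; sqlite3.connect would otherwise
    # silently create an empty database.
    if os.path.exists("Chinook.db"):
        print(list_tables("Chinook.db"))
```

If the script ran correctly, the printed list should contain the same table names that db.get_usable_table_names() reports in the next step.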

Let’s test the connection to the created SQLite database.

from langchain_community.utilities import SQLDatabase

# Specify the database to connect to
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

print(db.dialect)  # Database type
print(db.get_usable_table_names())  # List of usable table names

# Execute an SQL statement
db.run("SELECT * FROM Artist LIMIT 10;")
  • Expected output:

    sqlite
    ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
    
    "[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"
    

Using create_sql_query_chain, we can create a chain that converts natural language queries into SQL statements.

from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(llm, db)

Invoke the created chain as follows:

response = chain.invoke({"question": "How many employees are there?"})
print(response)
  • Expected output:

    SELECT COUNT("EmployeeId") AS "TotalEmployees" FROM "Employee"
    

Execute this SQL statement in the database to get the result.

db.run(response)
  • Expected output:

    '[(8,)]'
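
Note that db.run returns the rows as a string, not as Python objects. If the value is needed for further processing, one option is to parse it with ast.literal_eval. The sketch below assumes the string contains only plain literals, as in the output above; the helper name parse_rows is hypothetical.

```python
import ast


def parse_rows(result: str) -> list:
    """Parse the string form of a query result into a list of row tuples."""
    if not result:
        # Assumption: an empty string means the query returned no rows.
        return []
    return ast.literal_eval(result)


rows = parse_rows("[(8,)]")
print(rows[0][0])  # prints 8
```

This approach would not work for values whose string form is not a Python literal (dates or decimals, for example), so treat it as a convenience for simple results only.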
    

By chaining QuerySQLDataBaseTool, we can automatically execute the generated query and get the result.

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
chain.invoke({"question": "How many employees are there?"})
  • Expected output:

    '[(8,)]'
    
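Caution
This chain executes whatever SQL the model generates. For example, if you input {"question": "Delete all data from the Album table."}, a query that deletes all data from the Album table will be generated and executed automatically. When using this in production, connect with a database user that has only the permissions it needs, ideally read-only.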
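
One way to limit the damage for a SQLite file is to open it in read-only mode, so that write statements are rejected by the database itself. The following is a minimal sketch using only the standard sqlite3 module; the helper names connect_read_only and try_statement are hypothetical, and wiring a read-only connection into SQLDatabase is not shown here.

```python
import sqlite3


def connect_read_only(db_path: str) -> sqlite3.Connection:
    """Open an existing SQLite file in read-only mode via a URI filename."""
    # Assumes db_path contains no characters that need URI escaping.
    return sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)


def try_statement(conn: sqlite3.Connection, sql: str) -> str:
    """Run one statement and report whether the database accepted it."""
    try:
        conn.execute(sql)
        return "ok"
    except sqlite3.OperationalError as exc:
        return f"rejected: {exc}"
```

With a connection opened this way, try_statement(conn, "SELECT * FROM Album") should return "ok", while a DELETE statement should come back as rejected. Other database engines offer the same protection through user roles and grants.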

The chain we created above directly returns the SQL query result. To output the answer in natural language based on the query result, create a chain like this:

from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Based on the user's question, the corresponding SQL query, and the SQL result, answer the user's question.

Question: {question}
SQLQuery: {query}
SQLResult: {result}
Answer: """
)

# Chain to generate answers using the above prompt
answer = answer_prompt | llm | StrOutputParser()

# Chain to pass the SQL query result to the answer chain
chain = (
    RunnablePassthrough.assign(
        query=write_query
    ).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "How many employees are there?"})
  • Expected output:

    'There are a total of 8 employees.'
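
The data flow through this chain can be easier to follow in plain Python. The sketch below imitates what the two assign steps do to the input dictionary. It uses stand-in functions (fake_write_query, fake_execute_query, and assign are hypothetical placeholders, not LangChain components) so that it runs without an API key.

```python
def fake_write_query(state: dict) -> str:
    """Stand-in for write_query: returns a fixed SQL string."""
    return 'SELECT COUNT("EmployeeId") FROM "Employee"'


def fake_execute_query(query: str) -> str:
    """Stand-in for execute_query: returns a fixed result string."""
    return "[(8,)]"


def assign(state: dict, **steps) -> dict:
    """Return a copy of state with one new key per step.

    Each step receives the current state, which is roughly how
    RunnablePassthrough.assign adds keys to its input dictionary.
    """
    return {**state, **{key: fn(state) for key, fn in steps.items()}}


state = {"question": "How many employees are there?"}
state = assign(state, query=fake_write_query)
state = assign(state, result=lambda s: fake_execute_query(s["query"]))
print(sorted(state))  # prints ['query', 'question', 'result']
```

After both steps the dictionary holds question, query, and result, which are exactly the three variables that answer_prompt expects.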
    

Let’s review the source code of the create_sql_query_chain function to understand how it works.

def create_sql_query_chain(
    llm: BaseLanguageModel,
    db: SQLDatabase,
    prompt: Optional[BasePromptTemplate] = None,
    k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
    """Create a chain that generates SQL queries.

    *Security Note*: This chain generates SQL queries for the given database.

        The SQLDatabase class provides a get_table_info method that can be used
        to get column information as well as sample data from the table.

        To mitigate risk of leaking sensitive data, limit permissions
        to read and scope to the tables that are needed.

        Optionally, use the SQLInputWithTables input type to specify which tables
        are allowed to be accessed.

        Control access to who can submit requests to this chain.

        See https://python.langchain.com/docs/security for more information.

    Args:
        llm: The language model to use.
        db: The SQLDatabase to generate the query for.
        prompt: The prompt to use. If none is provided, will choose one
            based on dialect. Defaults to None. See Prompt section below for more.
        k: The number of results per select statement to return. Defaults to 5.

    Returns:
        A chain that takes in a question and generates a SQL query that answers
        that question.

    Example:

        .. code-block:: python

            # pip install -U langchain langchain-community langchain-openai
            from langchain_openai import ChatOpenAI
            from langchain.chains import create_sql_query_chain
            from langchain_community.utilities import SQLDatabase

            db = SQLDatabase.from_uri("sqlite:///Chinook.db")
            llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
            chain = create_sql_query_chain(llm, db)
            response = chain.invoke({"question": "How many employees are there"})

    Prompt:
        If no prompt is provided, a default prompt is selected based on the SQLDatabase dialect. If one is provided, it must support input variables:
            * input: The user question plus suffix "\nSQLQuery: " is passed here.
            * top_k: The number of results per select statement (the `k` argument to
                this function) is passed in here.
            * table_info: Table definitions and sample rows are passed in here. If the
                user specifies "table_names_to_use" when invoking chain, only those
                will be included. Otherwise, all tables are included.
            * dialect (optional): If dialect input variable is in prompt, the db
                dialect will be passed in here.

        Here's an example prompt:

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate

            template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
            Use the following format:

            Question: "Question here"
            SQLQuery: "SQL Query to run"
            SQLResult: "Result of the SQLQuery"
            Answer: "Final answer here"

            Only use the following tables:

            {table_info}.

            Question: {input}'''
            prompt = PromptTemplate.from_template(template)
    """  # noqa: E501
    if prompt is not None:
        prompt_to_use = prompt
    elif db.dialect in SQL_PROMPTS:
        prompt_to_use = SQL_PROMPTS[db.dialect]
    else:
        prompt_to_use = PROMPT
    if {"input", "top_k", "table_info"}.difference(prompt_to_use.input_variables):
        raise ValueError(
            f"Prompt must have input variables: 'input', 'top_k', "
            f"'table_info'. Received prompt with input variables: "
            f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
        )
    if "dialect" in prompt_to_use.input_variables:
        prompt_to_use = prompt_to_use.partial(dialect=db.dialect)

    inputs = {
        "input": lambda x: x["question"] + "\nSQLQuery: ",
        "table_info": lambda x: db.get_table_info(
            table_names=x.get("table_names_to_use")
        ),
    }
    return (
        RunnablePassthrough.assign(**inputs)  # type: ignore
        | (
            lambda x: {
                k: v
                for k, v in x.items()
                if k not in ("question", "table_names_to_use")
            }
        )
        | prompt_to_use.partial(top_k=str(k))
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | _strip
    )
  1. If the prompt argument is not specified, SQL_PROMPTS[db.dialect] is used when the dialect has an entry there; otherwise the generic PROMPT is used. In this case, db.dialect is sqlite, so the following prompt template is used.

    _sqlite_prompt = """You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
    Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
    Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
    Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
    Pay attention to use date('now') function to get the current date, if the question involves "today".
    
    Use the following format:
    
    Question: Question here
    SQLQuery: SQL Query to run
    SQLResult: Result of the SQLQuery
    Answer: Final answer here
    
    """
    
    PROMPT_SUFFIX = """Only use the following tables:
    {table_info}
    
    Question: {input}"""
    
    SQLITE_PROMPT = PromptTemplate(
        input_variables=["input", "table_info", "top_k"],
        template=_sqlite_prompt + PROMPT_SUFFIX,
    )
    
    • template

      You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
      Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
      Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
      Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
      Pay attention to use date('now') function to get the current date, if the question involves "today".
      
      Use the following format:
      
      Question: Question here
      SQLQuery: SQL Query to run
      SQLResult: Result of the SQLQuery
      Answer: Final answer here
      
      Only use the following tables:
      {table_info}
      
      Question: {input}
      
  2. The llm is bound with the stop sequence \nSQLResult:, so generation stops as soon as the model reaches it. Because the input value is “the value of the question key + \nSQLQuery: ”, the model is expected to output only the corresponding SQL statement.

  3. db.get_table_info outputs the table definitions and 3 rows of sample data.

    CREATE TABLE "Album" (
    	"AlbumId" INTEGER NOT NULL, 
    	"Title" NVARCHAR(160) NOT NULL, 
    	"ArtistId" INTEGER NOT NULL, 
    	PRIMARY KEY ("AlbumId"), 
    	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
    )
    
    /*
    3 rows from Album table:
    AlbumId	Title	ArtistId
    1	For Those About To Rock We Salute You	1
    2	Balls to the Wall	2
    3	Restless and Wild	2
    */
    
  4. If the question doesn’t specify how many rows to return, the prompt instructs the model to limit the query to {top_k} rows, which defaults to 5. You can change this number with the k argument.

  5. The generated chain accepts a table_names_to_use key in the input dictionary. Only the listed tables are included in table_info, which reduces the number of tokens in the prompt.
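
The points above can be sketched in plain Python: how the input mapping fills the template, and what the stop sequence does to the model output. This is only an imitation under stated assumptions. TEMPLATE is a shortened stand-in for the real SQLite prompt, build_prompt and apply_stop are hypothetical helpers, and in the real chain the stop sequence is enforced by the model API rather than by string splitting.

```python
# Shortened stand-in for the real prompt template (assumption).
TEMPLATE = (
    "Query for at most {top_k} results.\n"
    "Only use the following tables:\n"
    "{table_info}\n\n"
    "Question: {input}"
)


def build_prompt(question: str, table_info: str, k: int = 5) -> str:
    """Fill the template the way the chain's input mapping does."""
    return TEMPLATE.format(
        input=question + "\nSQLQuery: ",  # suffix appended by the chain
        table_info=table_info,
        top_k=str(k),
    )


def apply_stop(text: str, stop: str = "\nSQLResult:") -> str:
    """Imitate a stop sequence: keep only the text before its first occurrence."""
    return text.split(stop, 1)[0].strip()


prompt = build_prompt("How many employees are there?", 'CREATE TABLE "Employee" (...)')
print(prompt.endswith("SQLQuery: "))  # prints True

raw = 'SELECT COUNT(*) FROM "Employee"\nSQLResult: [(8,)]\nAnswer: 8'
print(apply_stop(raw))  # prints SELECT COUNT(*) FROM "Employee"
```

Because the prompt ends with "SQLQuery: " and generation is cut off before "SQLResult:", the only text left is the SQL statement itself.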

Caution
Even if you specify the table_names_to_use key, the LLM might infer table names from the question and generate SQL that accesses tables outside the specified list. Do not use the table_names_to_use key as a substitute for proper permission management.
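
For SQLite specifically, table-level restrictions can be enforced at the connection level with an authorizer callback, so that a query touching a table outside the list fails no matter what the model generates. The following is a minimal sketch with the standard sqlite3 module; restrict_to_tables and is_allowed are hypothetical helper names, and connecting such a restricted connection to SQLDatabase is not shown here.

```python
import sqlite3


def restrict_to_tables(conn: sqlite3.Connection, allowed: set[str]) -> None:
    """Install an authorizer that permits only SELECTs on the allowed tables."""

    def authorizer(action, arg1, arg2, db_name, source):
        # SELECT statements and SQL functions such as COUNT are permitted.
        if action in (sqlite3.SQLITE_SELECT, sqlite3.SQLITE_FUNCTION):
            return sqlite3.SQLITE_OK
        # Column reads are permitted only on allowed tables (arg1 is the table).
        if action == sqlite3.SQLITE_READ and arg1 in allowed:
            return sqlite3.SQLITE_OK
        # Everything else (writes, schema changes, other tables) is denied.
        return sqlite3.SQLITE_DENY

    conn.set_authorizer(authorizer)


def is_allowed(conn: sqlite3.Connection, sql: str) -> bool:
    """Return True if the statement runs, False if the database rejects it."""
    try:
        conn.execute(sql).fetchall()
        return True
    except sqlite3.DatabaseError:
        return False
```

After calling restrict_to_tables(conn, {"Album"}), a SELECT on Album should succeed, while a SELECT on Customer or a DELETE on Album should be rejected. On server databases, the equivalent is a dedicated user with SELECT granted only on the required tables.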

Here’s the complete code to query a database in natural language and get the answer in natural language.

import os
from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# Set API key as an environment variable
with open(".openai") as f:
    os.environ["OPENAI_API_KEY"] = f.read().strip()

# Specify the database to connect to
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# LLM settings
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

# Tool to execute SQL statements
execute_query = QuerySQLDataBaseTool(db=db)

# Chain to generate SQL statements
write_query = create_sql_query_chain(llm, db)

# Prompt template to generate answers in natural language
answer_prompt = PromptTemplate.from_template(
    """Based on the user's question, the corresponding SQL query, and the SQL result, answer the user's question.

Question: {question}
SQLQuery: {query}
SQLResult: {result}
Answer: """
)

# Chain to generate answers using the above prompt
answer = answer_prompt | llm | StrOutputParser()

# Chain to pass the SQL query result to the answer chain
chain = (
    RunnablePassthrough.assign(
        query=write_query
    ).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

# Execute the query
chain.invoke({"question": "How many employees are there?"})
