LangChainでText-to-SQLを試す (1) - Chain

シリーズ - LangChainでText-to-SQLを試す

今回は、LangChainのドキュメント『Q&A over SQL+CSV』のQuickstartを参考に、
データベースの情報を自然言語で問い合わせして回答を得る、“Text-to-SQL"の実装について見ていきます。

  • 本記事では、LangChainのバージョン0.2.0を使用します。

    $ pip list|grep langchain
    langchain                0.2.0
    langchain-community      0.2.0
    langchain-core           0.2.0
    langchain-openai         0.1.7
    langchain-text-splitters 0.2.0
    

初めに必要となるライブラリをインストールします。

pip install --upgrade --quiet  langchain langchain-community langchain-openai

次に環境変数OPENAI_API_KEYを設定します。
作業用ディレクトリに.openaiファイルを作成してその中にOpenAIのAPIキーを記載し、
以下のコードを実行します。

import os

# APIキーを環境変数に設定
with open('.openai') as f:
    os.environ['OPENAI_API_KEY'] = f.read().strip()

今回、テスト用にChinookというデータベースを使用します。
Chinookデータベースは、デジタルメディアストアを模したサンプルデータベースです。
以下のようなテーブルが含まれています(詳細はリポジトリを参照):

  1. Artists: アーティスト情報
  2. Albums: アルバム情報
  3. Tracks: トラック(曲)情報
  4. Customers: 顧客情報
  5. Invoices: 請求書情報
  6. Employees: 従業員情報

Chinookデータベースを作成するには、こちらをダウンロードして作業用ディレクトリに配置し、
以下のコードを実行します。

import sqlite3

# SQLiteデータベースのファイル名
db_file = 'Chinook.db'
sql_file = 'Chinook_Sqlite.sql'

# データベース接続を開く
with sqlite3.connect(db_file) as conn:
    cursor = conn.cursor()
    
    # SQLファイルを読み込む
    with open(sql_file, 'r', encoding='utf-8') as file:
        sql_script = file.read()

    # SQLスクリプトを実行する
    cursor.executescript(sql_script)
    
    # コミットしてデータベースを保存
    conn.commit()

実行すると Chinook.db という名前のSQLiteデータベースが作業用ディレクトリに作成されます。

作成したSQLiteデータベースに接続してみましょう。

from langchain_community.utilities import SQLDatabase

# 接続先のデータベースを指定
db = SQLDatabase.from_uri("sqlite:///Chinook.db")


print(db.dialect)  # DBの種類
print(db.get_usable_table_names())  # 使用可能なテーブル名のリスト

# SQL文を実行
db.run("SELECT * FROM Artist LIMIT 10;")
  • 実行すると以下のような出力が得られます。

    sqlite
    ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
    
    "[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"
    

create_sql_query_chainを使用すると、自然言語の問い合わせをSQL文に変換するチェインを作成します。

from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(llm, db)

作成されたチェインのinvokeは以下のように使用します。

response = chain.invoke({"question": "従業員は何人いますか?"})
print(response)
  • 実行結果

    SELECT COUNT("EmployeeId") AS "TotalEmployees" FROM "Employee"
    

このSQL文をデータベースで実行することで、結果を得ることが出来ます。

db.run(response)
  • 実行結果

    '[(8,)]'
    

QuerySQLDataBaseToolを使用して以下のようにチェインをつなげると、
LLMで生成したクエリを自動で実行して結果を出力するチェインを作成出来ます。

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
chain.invoke({"question": "従業員は何人いますか?"})
  • 実行結果

    '[(8,)]'
    
注意
例えば、{"question": "Albumテーブルの全データを削除してください。"}という値を入力した場合、
自動でAlbumテーブルの全データ削除が実行されてしまいます。
実際に運用を行う場合は、データベースに接続するユーザの権限調整等の注意が必要です。

上で作成したチェインは、SQL文の実行結果をそのまま返すチェインでした。
この実行結果を元に自然言語で回答を出力するチェインは以下のようにして作成することが出来ます。

from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """次のユーザーの質問、対応するSQLクエリ、およびSQL結果をもとに、ユーザーの質問に答えてください。

Question: {question}
SQLQuery: {query}
SQLResult: {result}
Answer: """
)

# 上記プロンプトを用いて回答を生成するチェイン
answer = answer_prompt | llm | StrOutputParser()

# write_queryチェインの実行結果を`query`キーに保存し、
# `query`キーの値をexecute_queryチェインに渡した結果を`result`キーに保存して、
# その辞書をanswerチェインに渡した結果を出力するチェイン
chain = (
    RunnablePassthrough.assign(
        query=write_query
    ).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "従業員は何人いますか"})
  • 実行結果

    '従業員は8人います。'
    

関数create_sql_query_chainがどのような動作をしているか、コードを確認してみましょう(リンク)。

def create_sql_query_chain(
    llm: BaseLanguageModel,
    db: SQLDatabase,
    prompt: Optional[BasePromptTemplate] = None,
    k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
    """Create a chain that generates SQL queries.

    *Security Note*: This chain generates SQL queries for the given database.

        The SQLDatabase class provides a get_table_info method that can be used
        to get column information as well as sample data from the table.

        To mitigate risk of leaking sensitive data, limit permissions
        to read and scope to the tables that are needed.

        Optionally, use the SQLInputWithTables input type to specify which tables
        are allowed to be accessed.

        Control access to who can submit requests to this chain.

        See https://python.langchain.com/docs/security for more information.

    Args:
        llm: The language model to use.
        db: The SQLDatabase to generate the query for.
        prompt: The prompt to use. If none is provided, will choose one
            based on dialect. Defaults to None. See Prompt section below for more.
        k: The number of results per select statement to return. Defaults to 5.

    Returns:
        A chain that takes in a question and generates a SQL query that answers
        that question.

    Example:

        .. code-block:: python

            # pip install -U langchain langchain-community langchain-openai
            from langchain_openai import ChatOpenAI
            from langchain.chains import create_sql_query_chain
            from langchain_community.utilities import SQLDatabase

            db = SQLDatabase.from_uri("sqlite:///Chinook.db")
            llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
            chain = create_sql_query_chain(llm, db)
            response = chain.invoke({"question": "How many employees are there"})

    Prompt:
        If no prompt is provided, a default prompt is selected based on the SQLDatabase dialect. If one is provided, it must support input variables:
            * input: The user question plus suffix "\nSQLQuery: " is passed here.
            * top_k: The number of results per select statement (the `k` argument to
                this function) is passed in here.
            * table_info: Table definitions and sample rows are passed in here. If the
                user specifies "table_names_to_use" when invoking chain, only those
                will be included. Otherwise, all tables are included.
            * dialect (optional): If dialect input variable is in prompt, the db
                dialect will be passed in here.

        Here's an example prompt:

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate

            template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
            Use the following format:

            Question: "Question here"
            SQLQuery: "SQL Query to run"
            SQLResult: "Result of the SQLQuery"
            Answer: "Final answer here"

            Only use the following tables:

            {table_info}.

            Question: {input}'''
            prompt = PromptTemplate.from_template(template)
    """  # noqa: E501
    if prompt is not None:
        prompt_to_use = prompt
    elif db.dialect in SQL_PROMPTS:
        prompt_to_use = SQL_PROMPTS[db.dialect]
    else:
        prompt_to_use = PROMPT
    if {"input", "top_k", "table_info"}.difference(prompt_to_use.input_variables):
        raise ValueError(
            f"Prompt must have input variables: 'input', 'top_k', "
            f"'table_info'. Received prompt with input variables: "
            f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
        )
    if "dialect" in prompt_to_use.input_variables:
        prompt_to_use = prompt_to_use.partial(dialect=db.dialect)

    inputs = {
        "input": lambda x: x["question"] + "\nSQLQuery: ",
        "table_info": lambda x: db.get_table_info(
            table_names=x.get("table_names_to_use")
        ),
    }
    return (
        RunnablePassthrough.assign(**inputs)  # type: ignore
        | (
            lambda x: {
                k: v
                for k, v in x.items()
                if k not in ("question", "table_names_to_use")
            }
        )
        | prompt_to_use.partial(top_k=str(k))
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | _strip
    )
  1. promptオプションの指定がない場合、SQL_PROMPTS[db.dialect]が使用されます。
    今回、db.dialectsqliteであるため、以下のプロンプトテンプレートが使用されます。

    SQLITE_PROMPT = PromptTemplate(
        input_variables=["input", "table_info", "top_k"],
        template=_sqlite_prompt + PROMPT_SUFFIX,
    )
    
    _sqlite_prompt = """You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
    Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
    Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
    Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
    Pay attention to use date('now') function to get the current date, if the question involves "today".
    
    Use the following format:
    
    Question: Question here
    SQLQuery: SQL Query to run
    SQLResult: Result of the SQLQuery
    Answer: Final answer here
    
    """
    
    PROMPT_SUFFIX = """Only use the following tables:
    {table_info}
    
    Question: {input}"""
    
    • templateの値

      You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
      Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
      Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
      Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
      Pay attention to use date('now') function to get the current date, if the question involves "today".
      
      Use the following format:
      
      Question: Question here
      SQLQuery: SQL Query to run
      SQLResult: Result of the SQLQuery
      Answer: Final answer here
      
      Only use the following tables:
      {table_info}
      
      Question: {input}
      
      • 日本語訳

        あなたはSQLiteの専門家です。入力された質問に対して、まず構文的に正しいSQLiteクエリを作成し、そのクエリの結果を見て、入力された質問に対する答えを返してください。
        ユーザーが質問内で取得する例の数を指定しない限り、SQLiteのLIMIT句を使用して最大{top_k}結果をクエリします。データベース内の最も有益なデータを返すように結果を並べ替えることができます。
        決してテーブルから全ての列をクエリしないでください。質問に答えるために必要な列だけをクエリする必要があります。各列名を二重引用符(")で囲んで、それらを区切られた識別子として示してください。
        以下のテーブルで見える列名のみを使用するように注意してください。存在しない列をクエリしないように注意してください。また、どの列がどのテーブルにあるかにも注意してください。
        質問が「今日」を含む場合は、現在の日付を取得するためにdate('now')関数を使用するように注意してください。
        
        次の形式を使用してください:
        
        Question: ここに質問
        SQLQuery: 実行するSQLクエリ
        SQLResult: SQLクエリの結果
        Answer: ここに最終回答
        
        以下のテーブルのみを使用してください:
        {table_info}
        
        Question: {input}
        
  2. 回答内に\nSQLResult:が出現した時点で、以降の文章を切り捨てます。
    {input}の箇所には「questionキーの値 + \nSQLQuery: 」が入力されるため、
    対応するSQL文のみが出力されることを期待しています。

  3. db.get_table_infoは以下のような各テーブルの定義と3行のサンプルを出力します。

    CREATE TABLE "Album" (
    	"AlbumId" INTEGER NOT NULL, 
    	"Title" NVARCHAR(160) NOT NULL, 
    	"ArtistId" INTEGER NOT NULL, 
    	PRIMARY KEY ("AlbumId"), 
    	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
    )
    
    /*
    3 rows from Album table:
    AlbumId	Title	ArtistId
    1	For Those About To Rock We Salute You	1
    2	Balls to the Wall	2
    3	Restless and Wild	2
    */
    
  4. 問い合わせ文に行数の指定がない場合、デフォルトでは最大5行の結果の取得をLLMに指示します。
    この行数を変更したい場合はkオプションで指定します。

  5. 生成されるチェインは、invokeメソッドで入力する辞書のtable_names_to_useキーに
    使用するテーブル名のリストを指定することが出来ます。
    (指定すると{table_info}が短くなり文字数の節約になります。)

注意
table_names_to_useキーを指定した場合でも、LLMが問い合わせを元にテーブル名を推測して、
指定したテーブル以外の情報を取得することが可能です。
権限調整の代わりとしてtable_names_to_useキーを使用することは適切ではないため注意してください。

データベースに対して自然言語で問い合わせを行い、自然言語で回答を得るコードをまとめておきます。

import os
from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# APIキーを環境変数に設定
with open(".openai") as f:
    os.environ["OPENAI_API_KEY"] = f.read().strip()

# 接続先のデータベースを指定
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# LLM設定
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

# SQL文を実行するツール
execute_query = QuerySQLDataBaseTool(db=db)

# SQL文を生成するチェイン
write_query = create_sql_query_chain(llm, db)

# 自然言語で回答を生成するためのプロンプトテンプレート
answer_prompt = PromptTemplate.from_template(
    """次のユーザーの質問、対応するSQLクエリ、およびSQL結果をもとに、ユーザーの質問に答えてください。

Question: {question}
SQLQuery: {query}
SQLResult: {result}
Answer: """
)

# 上記プロンプトを用いて回答を生成するチェイン
answer = answer_prompt | llm | StrOutputParser()

# write_queryチェインの実行結果を`query`キーに保存し、
# `query`キーの値をexecute_queryチェインに渡した結果を`result`キーに保存して、
# その辞書をanswerチェインに渡した結果を出力するチェイン
chain = (
    RunnablePassthrough.assign(
        query=write_query
    ).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

# 問い合わせを実行
chain.invoke({"question": "従業員は何人いますか"})

関連記事