LangChainでText-to-SQLを試す (2) - Agent
シリーズ - LangChainでText-to-SQLを試す
コンテンツ
前回に引き続き、LangChainのドキュメント『Q&A over SQL+CSV』のQuickstartを参考に、
データベースの情報を自然言語で問い合わせして回答を得る、“Text-to-SQL"の実装について見ていきます。
今回は、create_sql_agent
関数を使用して作成するエージェントの動作について確認します。
-
本記事では、LangChainのバージョン
0.2.0
を使用します。$ pip list|grep langchain langchain 0.2.0 langchain-community 0.2.0 langchain-core 0.2.0 langchain-openai 0.1.7 langchain-text-splitters 0.2.0
準備
前回の準備と同様に環境を用意します。
import os
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
# APIキーを環境変数に設定
with open(".openai") as f:
os.environ["OPENAI_API_KEY"] = f.read().strip()
# 接続先のデータベースを指定
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
# LLM設定
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
3. エージェント
関数create_sql_agent
を使用して、Text-to-SQLを行うエージェントを作成します。
Quickstartによると、エージェントはチェインと比較して以下ようなの利点があるようです:
- 生成したクエリの実行時にエラーが出た場合に再試行を実施する。
- 複数のクエリ実行を必要とする質問に回答が可能。
- 必要なテーブルのスキーマのみを考慮するため、トークン数の節約になる。
以下コードでエージェントを作成します。
from langchain_community.agent_toolkits import create_sql_agent
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
3.1. 使用例1: 複数テーブルにまたがる集計
agent_executor.invoke({
"input": "国毎の合計売上高のランキングを出力してください。どの国の顧客が一番お金を使っているでしょうか?"
})
-
実行ログ
> Entering new SQL Agent Executor chain... Invoking: `sql_db_list_tables` with `{}` Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}` CREATE TABLE "Customer" ( "CustomerId" INTEGER NOT NULL, "FirstName" NVARCHAR(40) NOT NULL, "LastName" NVARCHAR(20) NOT NULL, "Company" NVARCHAR(80), "Address" NVARCHAR(70), "City" NVARCHAR(40), "State" NVARCHAR(40), "Country" NVARCHAR(40), "PostalCode" NVARCHAR(10), "Phone" NVARCHAR(24), "Fax" NVARCHAR(24), "Email" NVARCHAR(60) NOT NULL, "SupportRepId" INTEGER, PRIMARY KEY ("CustomerId"), FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId") ) /* 3 rows from Customer table: CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId 1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 [email protected] 3 2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None [email protected] 5 3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None [email protected] 3 */ CREATE TABLE "Invoice" ( "InvoiceId" INTEGER NOT NULL, "CustomerId" INTEGER NOT NULL, "InvoiceDate" DATETIME NOT NULL, "BillingAddress" NVARCHAR(70), "BillingCity" NVARCHAR(40), "BillingState" NVARCHAR(40), "BillingCountry" NVARCHAR(40), "BillingPostalCode" NVARCHAR(10), "Total" NUMERIC(10, 2) NOT NULL, PRIMARY KEY ("InvoiceId"), FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId") ) /* 3 rows from Invoice table: InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total 1 2 2021-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98 2 4 2021-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96 3 8 2021-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94 */ CREATE TABLE "InvoiceLine" ( "InvoiceLineId" INTEGER NOT NULL, "InvoiceId" INTEGER NOT NULL, "TrackId" INTEGER NOT NULL, "UnitPrice" NUMERIC(10, 2) NOT NULL, "Quantity" INTEGER NOT NULL, PRIMARY KEY ("InvoiceLineId"), FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId") ) /* 3 rows from InvoiceLine table: InvoiceLineId InvoiceId TrackId UnitPrice Quantity 1 1 2 0.99 1 2 1 4 0.99 1 3 2 6 0.99 1 */ Invoking: `sql_db_query` with `{'query': 'SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC'}` [('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]合計売上高のランキングは以下の通りです: 1. USA: $523.06 2. Canada: $303.96 3. France: $195.10 4. Brazil: $190.10 5. Germany: $156.48 USAの顧客が一番お金を使っています。 > Finished chain.
-
実行結果
{'input': '国毎の合計売上高のランキングを出力してください。どの国の顧客が一番お金を使っているでしょうか?', 'output': '合計売上高のランキングは以下の通りです:\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n\nUSAの顧客が一番お金を使っています。'}
3.2. 使用例2: テーブルの説明
agent_executor.invoke({"input": "プレイリストトラックテーブルについて説明してください。"})
-
実行ログ
> Entering new SQL Agent Executor chain... Invoking: `sql_db_list_tables` with `{}` Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track Invoking: `sql_db_schema` with `{'table_names': 'PlaylistTrack'}` CREATE TABLE "PlaylistTrack" ( "PlaylistId" INTEGER NOT NULL, "TrackId" INTEGER NOT NULL, PRIMARY KEY ("PlaylistId", "TrackId"), FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId") ) /* 3 rows from PlaylistTrack table: PlaylistId TrackId 1 3402 1 3389 1 3390 */プレイリストトラックテーブル(PlaylistTrack)には以下のカラムが含まれています: - PlaylistId: プレイリストを識別するためのID(INTEGER) - TrackId: トラックを識別するためのID(INTEGER) また、プレイリストトラックテーブルにはプレイリストとトラックの関連性を示すための主キーが設定されています。 プレイリストトラックテーブルからの例として、以下のようなデータが含まれています: - PlaylistId: 1, TrackId: 3402 - PlaylistId: 1, TrackId: 3389 - PlaylistId: 1, TrackId: 3390 これは、プレイリストIDが1のプレイリストにそれぞれ異なるトラックが関連付けられていることを示しています。 > Finished chain.
-
実行結果
{'input': 'プレイリストトラックテーブルについて説明してください。', 'output': 'プレイリストトラックテーブル(PlaylistTrack)には以下のカラムが含まれています:\n- PlaylistId: プレイリストを識別するためのID(INTEGER)\n- TrackId: トラックを識別するためのID(INTEGER)\n\nまた、プレイリストトラックテーブルにはプレイリストとトラックの関連性を示すための主キーが設定されています。\n\nプレイリストトラックテーブルからの例として、以下のようなデータが含まれています:\n- PlaylistId: 1, TrackId: 3402\n- PlaylistId: 1, TrackId: 3389\n- PlaylistId: 1, TrackId: 3390\n\nこれは、プレイリストIDが1のプレイリストにそれぞれ異なるトラックが関連付けられていることを示しています。'}