LangChainでText-to-SQLを試す (2) - Agent

シリーズ - LangChainでText-to-SQLを試す

前回に引き続き、LangChainのドキュメント『Q&A over SQL+CSV』のQuickstartを参考に、
データベースの情報を自然言語で問い合わせして回答を得る、“Text-to-SQL"の実装について見ていきます。

今回は、create_sql_agent関数を使用して作成するエージェントの動作について確認します。

  • 本記事では、LangChainのバージョン0.2.0を使用します。

    $ pip list|grep langchain
    langchain                0.2.0
    langchain-community      0.2.0
    langchain-core           0.2.0
    langchain-openai         0.1.7
    langchain-text-splitters 0.2.0
    

前回の準備と同様に環境を用意します。

import os

from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI

# APIキーを環境変数に設定
with open(".openai") as f:
    os.environ["OPENAI_API_KEY"] = f.read().strip()

# 接続先のデータベースを指定
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# LLM設定
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

関数create_sql_agentを使用して、Text-to-SQLを行うエージェントを作成します。
Quickstartによると、エージェントはチェインと比較して以下ようなの利点があるようです:

  1. 生成したクエリの実行時にエラーが出た場合に再試行を実施する。
  2. 複数のクエリ実行を必要とする質問に回答が可能。
  3. 必要なテーブルのスキーマのみを考慮するため、トークン数の節約になる。

以下コードでエージェントを作成します。

from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
agent_executor.invoke({
    "input": "国毎の合計売上高のランキングを出力してください。どの国の顧客が一番お金を使っているでしょうか?"
})
  • 実行ログ

    > Entering new SQL Agent Executor chain...
    
    Invoking: `sql_db_list_tables` with `{}`
    
    
    Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
    Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}`
    
    
    
    CREATE TABLE "Customer" (
    	"CustomerId" INTEGER NOT NULL, 
    	"FirstName" NVARCHAR(40) NOT NULL, 
    	"LastName" NVARCHAR(20) NOT NULL, 
    	"Company" NVARCHAR(80), 
    	"Address" NVARCHAR(70), 
    	"City" NVARCHAR(40), 
    	"State" NVARCHAR(40), 
    	"Country" NVARCHAR(40), 
    	"PostalCode" NVARCHAR(10), 
    	"Phone" NVARCHAR(24), 
    	"Fax" NVARCHAR(24), 
    	"Email" NVARCHAR(60) NOT NULL, 
    	"SupportRepId" INTEGER, 
    	PRIMARY KEY ("CustomerId"), 
    	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
    )
    
    /*
    3 rows from Customer table:
    CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
    1	Luís	Gonçalves	Embraer - Empresa Brasileira de Aeronáutica S.A.	Av. Brigadeiro Faria Lima, 2170	São José dos Campos	SP	Brazil	12227-000	+55 (12) 3923-5555	+55 (12) 3923-5566	[email protected]	3
    2	Leonie	Köhler	None	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	+49 0711 2842222	None	[email protected]	5
    3	François	Tremblay	None	1498 rue Bélanger	Montréal	QC	Canada	H2G 1A7	+1 (514) 721-4711	None	[email protected]	3
    */
    
    
    CREATE TABLE "Invoice" (
    	"InvoiceId" INTEGER NOT NULL, 
    	"CustomerId" INTEGER NOT NULL, 
    	"InvoiceDate" DATETIME NOT NULL, 
    	"BillingAddress" NVARCHAR(70), 
    	"BillingCity" NVARCHAR(40), 
    	"BillingState" NVARCHAR(40), 
    	"BillingCountry" NVARCHAR(40), 
    	"BillingPostalCode" NVARCHAR(10), 
    	"Total" NUMERIC(10, 2) NOT NULL, 
    	PRIMARY KEY ("InvoiceId"), 
    	FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
    )
    
    /*
    3 rows from Invoice table:
    InvoiceId	CustomerId	InvoiceDate	BillingAddress	BillingCity	BillingState	BillingCountry	BillingPostalCode	Total
    1	2	2021-01-01 00:00:00	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	1.98
    2	4	2021-01-02 00:00:00	Ullevålsveien 14	Oslo	None	Norway	0171	3.96
    3	8	2021-01-03 00:00:00	Grétrystraat 63	Brussels	None	Belgium	1000	5.94
    */
    
    
    CREATE TABLE "InvoiceLine" (
    	"InvoiceLineId" INTEGER NOT NULL, 
    	"InvoiceId" INTEGER NOT NULL, 
    	"TrackId" INTEGER NOT NULL, 
    	"UnitPrice" NUMERIC(10, 2) NOT NULL, 
    	"Quantity" INTEGER NOT NULL, 
    	PRIMARY KEY ("InvoiceLineId"), 
    	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
    	FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
    )
    
    /*
    3 rows from InvoiceLine table:
    InvoiceLineId	InvoiceId	TrackId	UnitPrice	Quantity
    1	1	2	0.99	1
    2	1	4	0.99	1
    3	2	6	0.99	1
    */
    Invoking: `sql_db_query` with `{'query': 'SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC'}`
    
    
    [('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]合計売上高のランキングは以下の通りです:
    1. USA: $523.06
    2. Canada: $303.96
    3. France: $195.10
    4. Brazil: $190.10
    5. Germany: $156.48
    
    USAの顧客が一番お金を使っています。
    
    > Finished chain.
    
  • 実行結果

    {'input': '国毎の合計売上高のランキングを出力してください。どの国の顧客が一番お金を使っているでしょうか?',
     'output': '合計売上高のランキングは以下の通りです:\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n\nUSAの顧客が一番お金を使っています。'}
    
agent_executor.invoke({"input": "プレイリストトラックテーブルについて説明してください。"})
  • 実行ログ

    > Entering new SQL Agent Executor chain...
    
    Invoking: `sql_db_list_tables` with `{}`
    
    
    Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
    Invoking: `sql_db_schema` with `{'table_names': 'PlaylistTrack'}`
    
    
    
    CREATE TABLE "PlaylistTrack" (
    	"PlaylistId" INTEGER NOT NULL, 
    	"TrackId" INTEGER NOT NULL, 
    	PRIMARY KEY ("PlaylistId", "TrackId"), 
    	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
    	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
    )
    
    /*
    3 rows from PlaylistTrack table:
    PlaylistId	TrackId
    1	3402
    1	3389
    1	3390
    */プレイリストトラックテーブル(PlaylistTrack)には以下のカラムが含まれています:
    - PlaylistId: プレイリストを識別するためのID(INTEGER)
    - TrackId: トラックを識別するためのID(INTEGER)
    
    また、プレイリストトラックテーブルにはプレイリストとトラックの関連性を示すための主キーが設定されています。
    
    プレイリストトラックテーブルからの例として、以下のようなデータが含まれています:
    - PlaylistId: 1, TrackId: 3402
    - PlaylistId: 1, TrackId: 3389
    - PlaylistId: 1, TrackId: 3390
    
    これは、プレイリストIDが1のプレイリストにそれぞれ異なるトラックが関連付けられていることを示しています。
    
    > Finished chain.
    
  • 実行結果

    {'input': 'プレイリストトラックテーブルについて説明してください。',
     'output': 'プレイリストトラックテーブル(PlaylistTrack)には以下のカラムが含まれています:\n- PlaylistId: プレイリストを識別するためのID(INTEGER)\n- TrackId: トラックを識別するためのID(INTEGER)\n\nまた、プレイリストトラックテーブルにはプレイリストとトラックの関連性を示すための主キーが設定されています。\n\nプレイリストトラックテーブルからの例として、以下のようなデータが含まれています:\n- PlaylistId: 1, TrackId: 3402\n- PlaylistId: 1, TrackId: 3389\n- PlaylistId: 1, TrackId: 3390\n\nこれは、プレイリストIDが1のプレイリストにそれぞれ異なるトラックが関連付けられていることを示しています。'}
    

関連記事