Text-to-SQL with LangChain (2) - Agent
Series - Text-to-SQL with LangChain
Continuing from the last post, we’ll explore implementing Text-to-SQL, which lets you query a database in natural language and get answers, following LangChain’s Quickstart documentation.
In this post, we’ll look at how an agent created with the create_sql_agent function works.
Note: this article uses LangChain version 0.2.0.
$ pip list | grep langchain
langchain                 0.2.0
langchain-community       0.2.0
langchain-core            0.2.0
langchain-openai          0.1.7
langchain-text-splitters  0.2.0
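If you prefer to check the installed versions from Python rather than with `pip list`, a minimal sketch using the standard library's `importlib.metadata` (Python 3.8+) could look like this. The helper name `installed_versions` is our own, not part of LangChain.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package name: version string, or None if it is not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    pkgs = ["langchain", "langchain-community", "langchain-core", "langchain-openai"]
    for pkg, ver in installed_versions(pkgs).items():
        print(f"{pkg}: {ver or 'not installed'}")
```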
Setup
Prepare the environment as in the last post.
import os
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
# Set API key as an environment variable
with open(".openai") as f:
    os.environ["OPENAI_API_KEY"] = f.read().strip()
# Specify the database to connect to
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
# LLM settings
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
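One pitfall worth guarding against: if `Chinook.db` is not in the working directory, SQLite will typically create a new, empty database file at that path instead of failing, and the agent then sees no tables at all. The sketch below (the helper `check_sqlite_db` is hypothetical, not part of LangChain) opens the file read-only with the standard `sqlite3` module and lists its tables, so you can verify the file before building `SQLDatabase`.

```python
import sqlite3
from pathlib import Path


def check_sqlite_db(path):
    """Return the table names in an existing SQLite file, opened read-only.

    Hypothetical helper: raises FileNotFoundError instead of letting SQLite
    create a new, empty database at a mistyped path.
    """
    db_path = Path(path).resolve()
    if not db_path.is_file():
        raise FileNotFoundError(f"SQLite database not found: {db_path}")
    # mode=ro opens the file read-only; SQLite never creates a file in this mode
    conn = sqlite3.connect(db_path.as_uri() + "?mode=ro", uri=True)
    try:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
    finally:
        conn.close()
    return [name for (name,) in rows]
```

Calling `check_sqlite_db("Chinook.db")` should return the same tables that appear in the `sql_db_list_tables` output in the execution logs.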
3. Agent
We’ll create an agent that performs Text-to-SQL using the create_sql_agent function.
According to the Quickstart, agents have several advantages over chains:
- They retry if an error occurs when executing the generated query.
- They can answer questions that require multiple query executions.
- They fetch only the schemas of the tables they need, which saves tokens.
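The first advantage, retrying after a failed query, essentially comes down to a loop that feeds the database error back to the model. The sketch below is not LangChain's implementation: `propose_query` is a hypothetical stand-in for the LLM (here it just returns scripted queries), and an in-memory `sqlite3` database with made-up rows replaces the real one.

```python
import sqlite3


def run_with_retries(conn, propose_query, max_attempts=3):
    """Execute model-proposed SQL, feeding errors back until a query succeeds."""
    error = None
    for attempt in range(1, max_attempts + 1):
        sql = propose_query(error)  # the "LLM" sees the previous error, if any
        try:
            return attempt, sql, conn.execute(sql).fetchall()
        except sqlite3.Error as exc:
            error = f"{type(exc).__name__}: {exc}"
    raise RuntimeError(f"gave up after {max_attempts} attempts; last error: {error}")


# Toy data and a scripted "LLM" whose first guess uses a nonexistent column
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Invoice (InvoiceId INTEGER, BillingCountry TEXT, Total REAL)")
conn.executemany(
    "INSERT INTO Invoice VALUES (?, ?, ?)",
    [(1, "Germany", 1.98), (2, "Norway", 3.96), (3, "Germany", 5.94)],
)
guesses = iter([
    "SELECT Country, SUM(Total) FROM Invoice GROUP BY Country",
    "SELECT BillingCountry, SUM(Total) FROM Invoice "
    "GROUP BY BillingCountry ORDER BY BillingCountry",
])


def scripted_llm(previous_error):
    # A real model would use previous_error to repair its query; this one just
    # returns the next scripted guess.
    return next(guesses)


attempt, sql, rows = run_with_retries(conn, scripted_llm)
print(f"succeeded on attempt {attempt}: {rows}")
```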
Create the agent with the following code:
from langchain_community.agent_toolkits import create_sql_agent
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
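With `agent_type="openai-tools"`, the model chooses among the toolkit's SQL tools (such as `sql_db_list_tables`, `sql_db_schema`, and `sql_db_query`, which appear in the logs below) through OpenAI tool calling; the executor runs each requested tool and passes the result back to the model until it produces a final answer. The following is a deliberately simplified sketch of that loop, not LangChain's code: the tool names mirror the logs, while `scripted_model`, `run_agent`, and the message format are hypothetical stand-ins, and a toy in-memory `Playlist` table replaces Chinook.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Playlist (PlaylistId INTEGER PRIMARY KEY, Name TEXT);
INSERT INTO Playlist VALUES (1, 'Music'), (2, 'Movies');
""")


def sql_db_list_tables(_args):
    rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    return ", ".join(name for (name,) in rows)


def sql_db_schema(args):
    names = [n.strip() for n in args["table_names"].split(",")]
    placeholders = ",".join("?" * len(names))
    rows = conn.execute(
        f"SELECT sql FROM sqlite_master WHERE type='table' AND name IN ({placeholders})",
        names,
    )
    return "\n\n".join(sql for (sql,) in rows)


def sql_db_query(args):
    return str(conn.execute(args["query"]).fetchall())


TOOLS = {f.__name__: f for f in (sql_db_list_tables, sql_db_schema, sql_db_query)}

# Hypothetical "model": a fixed script of tool calls, then a final answer
SCRIPT = [
    ("sql_db_list_tables", {}),
    ("sql_db_schema", {"table_names": "Playlist"}),
    ("sql_db_query", {"query": "SELECT COUNT(*) FROM Playlist"}),
]


def scripted_model(history):
    step = sum(1 for m in history if m["role"] == "tool")
    if step < len(SCRIPT):
        name, args = SCRIPT[step]
        return {"tool": name, "args": args}
    return {"final": f"Query result: {history[-1]['content']}"}


def run_agent(question, model, max_steps=10):
    """Call tools as the model requests until it returns a final answer."""
    history = [{"role": "user", "content": question}]
    for _ in range(max_steps):
        action = model(history)
        if "final" in action:
            return action["final"], history
        observation = TOOLS[action["tool"]](action["args"])
        print(f"Invoking: `{action['tool']}` with `{action['args']}`\n{observation}")
        history.append({"role": "tool", "name": action["tool"], "content": observation})
    raise RuntimeError("agent did not finish within max_steps")


answer, history = run_agent("How many playlists are there?", scripted_model)
print(answer)
```

In the real agent the LLM decides which tool to call next based on the previous observations; the fixed script here only makes the control flow easy to follow.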
3.1. Example 1: Aggregation across Multiple Tables
agent_executor.invoke(
{
"input": "List the total sales per country. Which country's customers spent the most?"
}
)
Execution Log
> Entering new SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}`

CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL,
	"FirstName" NVARCHAR(40) NOT NULL,
	"LastName" NVARCHAR(20) NOT NULL,
	"Company" NVARCHAR(80),
	"Address" NVARCHAR(70),
	"City" NVARCHAR(40),
	"State" NVARCHAR(40),
	"Country" NVARCHAR(40),
	"PostalCode" NVARCHAR(10),
	"Phone" NVARCHAR(24),
	"Fax" NVARCHAR(24),
	"Email" NVARCHAR(60) NOT NULL,
	"SupportRepId" INTEGER,
	PRIMARY KEY ("CustomerId"),
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	Embraer - Empresa Brasileira de Aeronáutica S.A.	Av. Brigadeiro Faria Lima, 2170	São José dos Campos	SP	Brazil	12227-000	+55 (12) 3923-5555	+55 (12) 3923-5566	[email hidden]	3
2	Leonie	Köhler	None	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	+49 0711 2842222	None	[email hidden]	5
3	François	Tremblay	None	1498 rue Bélanger	Montréal	QC	Canada	H2G 1A7	+1 (514) 721-4711	None	[email hidden]	3
*/

CREATE TABLE "Invoice" (
	"InvoiceId" INTEGER NOT NULL,
	"CustomerId" INTEGER NOT NULL,
	"InvoiceDate" DATETIME NOT NULL,
	"BillingAddress" NVARCHAR(70),
	"BillingCity" NVARCHAR(40),
	"BillingState" NVARCHAR(40),
	"BillingCountry" NVARCHAR(40),
	"BillingPostalCode" NVARCHAR(10),
	"Total" NUMERIC(10, 2) NOT NULL,
	PRIMARY KEY ("InvoiceId"),
	FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)

/*
3 rows from Invoice table:
InvoiceId	CustomerId	InvoiceDate	BillingAddress	BillingCity	BillingState	BillingCountry	BillingPostalCode	Total
1	2	2021-01-01 00:00:00	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	1.98
2	4	2021-01-02 00:00:00	Ullevålsveien 14	Oslo	None	Norway	0171	3.96
3	8	2021-01-03 00:00:00	Grétrystraat 63	Brussels	None	Belgium	1000	5.94
*/

CREATE TABLE "InvoiceLine" (
	"InvoiceLineId" INTEGER NOT NULL,
	"InvoiceId" INTEGER NOT NULL,
	"TrackId" INTEGER NOT NULL,
	"UnitPrice" NUMERIC(10, 2) NOT NULL,
	"Quantity" INTEGER NOT NULL,
	PRIMARY KEY ("InvoiceLineId"),
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
	FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
)

/*
3 rows from InvoiceLine table:
InvoiceLineId	InvoiceId	TrackId	UnitPrice	Quantity
1	1	2	0.99	1
2	1	4	0.99	1
3	2	6	0.99	1
*/
Invoking: `sql_db_query` with `{'query': 'SELECT BillingCountry AS Country, SUM(Total) AS TotalSales FROM Invoice GROUP BY BillingCountry ORDER BY TotalSales DESC;'}`
responded: To find the total sales per country, we need to sum the total amount from the invoices for each country. Here is the query to achieve this:

```sql
SELECT BillingCountry AS Country, SUM(Total) AS TotalSales
FROM Invoice
GROUP BY BillingCountry
ORDER BY TotalSales DESC;
```

By running this query, we can determine which country's customers spent the most. Let me execute the query to provide you with the answer.

[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]
The total sales per country are as follows:

1. USA: $523.06
2. Canada: $303.96
3. France: $195.10
4. Brazil: $190.10
5. Germany: $156.48

Therefore, customers from the USA spent the most in total.
> Finished chain.
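The `sql_db_schema` output in the log above combines each requested table's `CREATE TABLE` statement with three sample rows, and only the tables the agent asked for are described, which is where the token savings come from. A rough stdlib approximation for SQLite follows; the helper `table_info` is hypothetical, and LangChain actually builds this text through SQLAlchemy, so its exact formatting differs.

```python
import sqlite3


def table_info(conn, table_names, sample_rows=3):
    """Approximate the schema tool: CREATE TABLE DDL plus a few sample rows per table."""
    chunks = []
    for name in table_names:
        row = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (name,)
        ).fetchone()
        if row is None:
            raise ValueError(f"unknown table: {name}")
        # Quote the identifier so it can be safely interpolated into the query
        quoted = '"' + name.replace('"', '""') + '"'
        cur = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,))
        header = "\t".join(col[0] for col in cur.description)
        body = "\n".join("\t".join(str(v) for v in r) for r in cur.fetchall())
        chunks.append(
            f"{row[0]}\n\n/*\n{sample_rows} rows from {name} table:\n{header}\n{body}\n*/"
        )
    return "\n\n".join(chunks)
```

With a connection to Chinook, `print(table_info(conn, ["Customer", "Invoice", "InvoiceLine"]))` should produce output similar in shape to the log above.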
Execution Result
{'input': "List the total sales per country. Which country's customers spent the most?", 'output': 'The total sales per country are as follows:\n\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n\nTherefore, customers from the USA spent the most in total.'}
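Note that the agent summed `Invoice.Total` grouped by `BillingCountry` instead of joining `Invoice` to `Customer` and grouping by `Customer.Country`. The two readings of "which country's customers" can differ whenever an invoice is billed to a country other than the customer's own. The sketch below compares both queries on made-up toy rows (not Chinook data) in an in-memory SQLite database.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER,
                      BillingCountry TEXT, Total NUMERIC);
INSERT INTO Customer VALUES (1, 'USA'), (2, 'Canada');
-- toy rows: customer 2 lives in Canada, but one invoice is billed to the USA
INSERT INTO Invoice VALUES (1, 1, 'USA', 10), (2, 2, 'Canada', 4), (3, 2, 'USA', 3);
""")

# The query the agent generated in the log above
by_billing = conn.execute("""
    SELECT BillingCountry AS Country, SUM(Total) AS TotalSales
    FROM Invoice GROUP BY BillingCountry ORDER BY TotalSales DESC
""").fetchall()

# Alternative reading: attribute each invoice to the customer's country
by_customer = conn.execute("""
    SELECT c.Country, SUM(i.Total) AS TotalSales
    FROM Invoice AS i JOIN Customer AS c ON c.CustomerId = i.CustomerId
    GROUP BY c.Country ORDER BY TotalSales DESC
""").fetchall()

print(by_billing)   # [('USA', 13), ('Canada', 4)]
print(by_customer)  # [('USA', 10), ('Canada', 7)]
```

If the distinction matters for your question, it is worth phrasing the request explicitly (for example, "by the customer's country") rather than relying on the agent's choice.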
3.2. Example 2: Describing Tables
agent_executor.invoke({"input": "Describe the playlisttrack table"})
Execution Log
> Entering new SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `{'table_names': 'PlaylistTrack'}`

CREATE TABLE "PlaylistTrack" (
	"PlaylistId" INTEGER NOT NULL,
	"TrackId" INTEGER NOT NULL,
	PRIMARY KEY ("PlaylistId", "TrackId"),
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
3 rows from PlaylistTrack table:
PlaylistId	TrackId
1	3402
1	3389
1	3390
*/
The `PlaylistTrack` table has the following columns:
- PlaylistId (INTEGER, NOT NULL)
- TrackId (INTEGER, NOT NULL)

It has a composite primary key on the columns PlaylistId and TrackId. Additionally, there are foreign key constraints on TrackId referencing the Track table's TrackId column, and on PlaylistId referencing the Playlist table's PlaylistId column.

Here are 3 sample rows from the `PlaylistTrack` table:
1. PlaylistId: 1, TrackId: 3402
2. PlaylistId: 1, TrackId: 3389
3. PlaylistId: 1, TrackId: 3390

> Finished chain.
Execution Result
{'input': 'Describe the playlisttrack table', 'output': "The `PlaylistTrack` table has the following columns:\n- PlaylistId (INTEGER, NOT NULL)\n- TrackId (INTEGER, NOT NULL)\n\nIt has a composite primary key on the columns PlaylistId and TrackId. Additionally, there are foreign key constraints on TrackId referencing the Track table's TrackId column, and on PlaylistId referencing the Playlist table's PlaylistId column.\n\nHere are 3 sample rows from the `PlaylistTrack` table:\n1. PlaylistId: 1, TrackId: 3402\n2. PlaylistId: 1, TrackId: 3389\n3. PlaylistId: 1, TrackId: 3390"}
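The composite primary key described in the answer means that each (PlaylistId, TrackId) pair can appear only once, while either column on its own may repeat. A quick check with `sqlite3` on a toy copy of the table follows; the sample values come from the rows above, and the foreign keys are omitted for brevity.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE PlaylistTrack (
        PlaylistId INTEGER NOT NULL,
        TrackId INTEGER NOT NULL,
        PRIMARY KEY (PlaylistId, TrackId)
    )
""")
# The same PlaylistId with different TrackIds is allowed
conn.executemany("INSERT INTO PlaylistTrack VALUES (?, ?)", [(1, 3402), (1, 3389), (1, 3390)])
# The same TrackId in a different playlist is also allowed
conn.execute("INSERT INTO PlaylistTrack VALUES (2, 3402)")

try:
    conn.execute("INSERT INTO PlaylistTrack VALUES (1, 3402)")  # duplicate pair
    duplicate_rejected = False
except sqlite3.IntegrityError as exc:
    # SQLite reports a UNIQUE constraint violation on the key columns
    duplicate_rejected = True
    print("rejected:", exc)
```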