Text-to-SQL with LangChain (2) - Agent


Continuing from the last post, we’ll explore an implementation of ‘Text-to-SQL’, which lets you query a database in natural language and get answers, based on LangChain’s Quickstart documentation.

In this post, we’ll look into the operation of an agent created using the create_sql_agent function.

  • In this article, we’re using LangChain version 0.2.0.

    $ pip list|grep langchain
    langchain                0.2.0
    langchain-community      0.2.0
    langchain-core           0.2.0
    langchain-openai         0.1.7
    langchain-text-splitters 0.2.0
    

Prepare the environment as in the last post.

import os

from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI

# Set API key as an environment variable
with open(".openai") as f:
    os.environ["OPENAI_API_KEY"] = f.read().strip()

# Specify the database to connect to
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# LLM settings
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

We’ll create an agent to perform Text-to-SQL using the create_sql_agent function.
According to the Quickstart, agents have several advantages over chains:

  1. They retry if an error occurs when executing the generated query.
  2. They can answer questions that require multiple query executions.
  3. They consider only the necessary table schemas, saving on token usage.
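
The first advantage, retrying after a failed query, can be sketched without LangChain. This is only an illustration of the idea, not how create_sql_agent is implemented: the helper run_with_retry is hypothetical, the candidate queries are supplied up front instead of being regenerated by an LLM from the error message, and the one-row Invoice table is a stand-in for the real database.

```python
import sqlite3


def run_with_retry(conn, candidate_queries):
    """Run candidate queries in order until one executes successfully.

    Hypothetical helper: a real agent would ask the LLM for a corrected
    query after each error; here the corrections are listed in advance.
    """
    errors = []
    for query in candidate_queries:
        try:
            return conn.execute(query).fetchall(), errors
        except sqlite3.Error as exc:
            # Keep the error text; an agent would feed it back to the LLM.
            errors.append(str(exc))
    raise RuntimeError(f"all candidate queries failed: {errors}")


# Stand-in table with a single row (not the real Chinook database).
conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "Invoice" ("InvoiceId" INTEGER, "BillingCountry" TEXT, "Total" NUMERIC)'
)
conn.execute("INSERT INTO Invoice VALUES (1, 'Germany', 1.98)")

rows, errors = run_with_retry(
    conn,
    [
        "SELECT Country FROM Invoice",         # wrong column name, raises an error
        "SELECT BillingCountry FROM Invoice",  # corrected query, succeeds
    ],
)
print(rows)    # [('Germany',)]
print(errors)  # one entry describing the failed first attempt
```

A chain that runs a single generated query would stop at the first error, whereas a loop like this one gets another attempt.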

Create the agent and ask it a question with the following code:

from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
agent_executor.invoke(
    {
        "input": "List the total sales per country. Which country's customers spent the most?"
    }
)

  • Execution Log

    > Entering new SQL Agent Executor chain...
    
    Invoking: `sql_db_list_tables` with `{}`
    
    
    Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
    Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}`
    
    
    
    CREATE TABLE "Customer" (
    	"CustomerId" INTEGER NOT NULL, 
    	"FirstName" NVARCHAR(40) NOT NULL, 
    	"LastName" NVARCHAR(20) NOT NULL, 
    	"Company" NVARCHAR(80), 
    	"Address" NVARCHAR(70), 
    	"City" NVARCHAR(40), 
    	"State" NVARCHAR(40), 
    	"Country" NVARCHAR(40), 
    	"PostalCode" NVARCHAR(10), 
    	"Phone" NVARCHAR(24), 
    	"Fax" NVARCHAR(24), 
    	"Email" NVARCHAR(60) NOT NULL, 
    	"SupportRepId" INTEGER, 
    	PRIMARY KEY ("CustomerId"), 
    	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
    )
    
    /*
    3 rows from Customer table:
    CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
    1	Luís	Gonçalves	Embraer - Empresa Brasileira de Aeronáutica S.A.	Av. Brigadeiro Faria Lima, 2170	São José dos Campos	SP	Brazil	12227-000	+55 (12) 3923-5555	+55 (12) 3923-5566	[email protected]	3
    2	Leonie	Köhler	None	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	+49 0711 2842222	None	[email protected]	5
    3	François	Tremblay	None	1498 rue Bélanger	Montréal	QC	Canada	H2G 1A7	+1 (514) 721-4711	None	[email protected]	3
    */
    
    
    CREATE TABLE "Invoice" (
    	"InvoiceId" INTEGER NOT NULL, 
    	"CustomerId" INTEGER NOT NULL, 
    	"InvoiceDate" DATETIME NOT NULL, 
    	"BillingAddress" NVARCHAR(70), 
    	"BillingCity" NVARCHAR(40), 
    	"BillingState" NVARCHAR(40), 
    	"BillingCountry" NVARCHAR(40), 
    	"BillingPostalCode" NVARCHAR(10), 
    	"Total" NUMERIC(10, 2) NOT NULL, 
    	PRIMARY KEY ("InvoiceId"), 
    	FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
    )
    
    /*
    3 rows from Invoice table:
    InvoiceId	CustomerId	InvoiceDate	BillingAddress	BillingCity	BillingState	BillingCountry	BillingPostalCode	Total
    1	2	2021-01-01 00:00:00	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	1.98
    2	4	2021-01-02 00:00:00	Ullevålsveien 14	Oslo	None	Norway	0171	3.96
    3	8	2021-01-03 00:00:00	Grétrystraat 63	Brussels	None	Belgium	1000	5.94
    */
    
    
    CREATE TABLE "InvoiceLine" (
    	"InvoiceLineId" INTEGER NOT NULL, 
    	"InvoiceId" INTEGER NOT NULL, 
    	"TrackId" INTEGER NOT NULL, 
    	"UnitPrice" NUMERIC(10, 2) NOT NULL, 
    	"Quantity" INTEGER NOT NULL, 
    	PRIMARY KEY ("InvoiceLineId"), 
    	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
    	FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
    )
    
    /*
    3 rows from InvoiceLine table:
    InvoiceLineId	InvoiceId	TrackId	UnitPrice	Quantity
    1	1	2	0.99	1
    2	1	4	0.99	1
    3	2	6	0.99	1
    */
    Invoking: `sql_db_query` with `{'query': 'SELECT BillingCountry AS Country, SUM(Total) AS TotalSales FROM Invoice GROUP BY BillingCountry ORDER BY TotalSales DESC;'}`
    responded: To find the total sales per country, we need to sum the total amount from the invoices for each country. Here is the query to achieve this:
    
      ```sql
      SELECT BillingCountry AS Country, SUM(Total) AS TotalSales
      FROM Invoice
      GROUP BY BillingCountry
      ORDER BY TotalSales DESC;
      ```
    
    By running this query, we can determine which country's customers spent the most. Let me execute the query to provide you with the answer.
    
    [('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]
    The total sales per country are as follows:
    
    1. USA: $523.06
    2. Canada: $303.96
    3. France: $195.10
    4. Brazil: $190.10
    5. Germany: $156.48
    
    Therefore, customers from the USA spent the most in total.
    
    > Finished chain.
    
  • Execution Result

    {'input': "List the total sales per country. Which country's customers spent the most?",
     'output': 'The total sales per country are as follows:\n\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n\nTherefore, customers from the USA spent the most in total.'}
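
Note that the query returned 24 countries, while the agent's final answer lists only the top five. If you need every row, one option is to run the generated SQL yourself. The following is a minimal sketch using Python's built-in sqlite3 module on a small in-memory table rather than Chinook.db: the first three rows are the Invoice sample rows shown in the log, the fourth row is a hypothetical addition so that one country has more than one invoice, and ROUND is added to hide floating-point noise such as 523.0600000000003.

```python
import sqlite3

# In-memory stand-in for the Invoice table (only the columns the query needs).
conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "Invoice" ('
    '"InvoiceId" INTEGER NOT NULL, '
    '"BillingCountry" NVARCHAR(40), '
    '"Total" NUMERIC(10, 2) NOT NULL)'
)
conn.executemany(
    "INSERT INTO Invoice VALUES (?, ?, ?)",
    [
        (1, "Germany", 1.98),  # sample row from the log
        (2, "Norway", 3.96),   # sample row from the log
        (3, "Belgium", 5.94),  # sample row from the log
        (4, "Germany", 0.99),  # hypothetical extra row for illustration
    ],
)

# Same shape as the agent's query, with ROUND added for display purposes.
query = (
    "SELECT BillingCountry AS Country, ROUND(SUM(Total), 2) AS TotalSales "
    "FROM Invoice GROUP BY BillingCountry ORDER BY TotalSales DESC"
)
totals = conn.execute(query).fetchall()

# Print every country, not just the top few.
for rank, (country, total) in enumerate(totals, start=1):
    print(f"{rank}. {country}: ${total:.2f}")
```

Against the real database, the same query can be run by connecting to Chinook.db instead of ":memory:" and skipping the table setup.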
    
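Next, ask the agent to describe a table. The question spells the table name in lowercase, and the agent resolves it to PlaylistTrack after listing the tables.

agent_executor.invoke({"input": "Describe the playlisttrack table"})

  • Execution Log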

    > Entering new SQL Agent Executor chain...
    
    Invoking: `sql_db_list_tables` with `{}`
    
    
    Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
    Invoking: `sql_db_schema` with `{'table_names': 'PlaylistTrack'}`
    
    
    
    CREATE TABLE "PlaylistTrack" (
    	"PlaylistId" INTEGER NOT NULL, 
    	"TrackId" INTEGER NOT NULL, 
    	PRIMARY KEY ("PlaylistId", "TrackId"), 
    	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
    	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
    )
    
    /*
    3 rows from PlaylistTrack table:
    PlaylistId	TrackId
    1	3402
    1	3389
    1	3390
    */
    The `PlaylistTrack` table has the following columns:
    - PlaylistId (INTEGER, NOT NULL)
    - TrackId (INTEGER, NOT NULL)
    
    It has a composite primary key on the columns PlaylistId and TrackId. Additionally, there are foreign key constraints on TrackId referencing the Track table's TrackId column, and on PlaylistId referencing the Playlist table's PlaylistId column.
    
    Here are 3 sample rows from the `PlaylistTrack` table:
    1. PlaylistId: 1, TrackId: 3402
    2. PlaylistId: 1, TrackId: 3389
    3. PlaylistId: 1, TrackId: 3390
    
    > Finished chain.
    
  • Execution Result

    {'input': 'Describe the playlisttrack table',
     'output': "The `PlaylistTrack` table has the following columns:\n- PlaylistId (INTEGER, NOT NULL)\n- TrackId (INTEGER, NOT NULL)\n\nIt has a composite primary key on the columns PlaylistId and TrackId. Additionally, there are foreign key constraints on TrackId referencing the Track table's TrackId column, and on PlaylistId referencing the Playlist table's PlaylistId column.\n\nHere are 3 sample rows from the `PlaylistTrack` table:\n1. PlaylistId: 1, TrackId: 3402\n2. PlaylistId: 1, TrackId: 3389\n3. PlaylistId: 1, TrackId: 3390"}
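
The agent's description can be cross-checked against SQLite's own metadata. The sketch below is an assumption-laden illustration, not part of the Quickstart: it recreates PlaylistTrack in an in-memory database from the CREATE TABLE statement shown in the log (the referenced Track and Playlist tables are not created), then reads the column and foreign key definitions with the PRAGMA table_info and PRAGMA foreign_key_list statements.

```python
import sqlite3

conn = sqlite3.connect(":memory:")

# DDL copied from the sql_db_schema output in the execution log.
conn.execute(
    """
    CREATE TABLE "PlaylistTrack" (
        "PlaylistId" INTEGER NOT NULL,
        "TrackId" INTEGER NOT NULL,
        PRIMARY KEY ("PlaylistId", "TrackId"),
        FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
        FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
    )
    """
)

# Each row: (cid, name, type, notnull, dflt_value, pk).
# pk is the column's 1-based position within the primary key, or 0.
columns = conn.execute('PRAGMA table_info("PlaylistTrack")').fetchall()
for _cid, name, col_type, notnull, _default, pk in columns:
    print(f"{name}: {col_type}, not null={bool(notnull)}, pk position={pk}")

# Each row: (id, seq, table, from, to, on_update, on_delete, match).
foreign_keys = conn.execute('PRAGMA foreign_key_list("PlaylistTrack")').fetchall()
fk_targets = {(row[3], row[2], row[4]) for row in foreign_keys}
for from_col, ref_table, ref_col in sorted(fk_targets):
    print(f"{from_col} -> {ref_table}.{ref_col}")
```

Both columns report a nonzero primary key position, which matches the composite primary key in the agent's answer.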
    
