LangChainの文章要約を行うコードを読む (1) - Stuff

今回はLangChainのドキュメントSummarizationで紹介されている、
文章を要約するチェインの仕組みについて詳しく見ていきます。
-
本記事では、LangChainのバージョン
0.1.17を使用します。$ pip list|grep langchain langchain 0.1.17 langchain-community 0.0.37 langchain-core 0.1.52 langchain-openai 0.1.6 langchain-text-splitters 0.0.1 langchainhub 0.1.15
1. 文章の要約を試す
まずは、ドキュメントのQuickstartで紹介されている、文章を要約するコードを見てみましょう。
別記事の『LangChainのQuickstartを読む』で行ったように、
OpenAIのAPIキーを.openaiファイルに保存しておき、環境変数に読み込むようにします。
import os
from langchain.chains.summarize import load_summarize_chain
from langchain_community.document_loaders import WebBaseLoader
from langchain_openai import ChatOpenAI
# APIキーを設定
with open('.openai') as f:
os.environ['OPENAI_API_KEY'] = f.read().strip()
# ドキュメント読み込み
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
docs = loader.load()
# LLM 設定
llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-1106")
# 要約を行うチェインを作成
chain = load_summarize_chain(llm, chain_type="stuff")
# 要約を実施
chain.run(docs)
実行すると以下のような要約が出力されました。
'The article discusses the concept of LLM-powered autonomous agents, with a focus on the components of planning, memory, and tool use. It includes case studies and proof-of-concept examples, as well as challenges and references to related research. The author also provides a citation for the article.'
- 関数
load_summarize_chainを使用して、文章要約を行うチェインを作成しています。- オプションの
chain_typeにはstuffが指定されていますが、
他にもmap_reduceとrefineが指定できます。
- オプションの
以下では、load_summarize_chainの動作の詳細と各オプションについて詳しく見ていきます。
2. load_summarize_chainのコードを確認する
load_summarize_chainは以下のように定義されています(リンク)。
def load_summarize_chain(
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load summarizing chain.
Args:
llm: Language Model to use in the chain.
chain_type: Type of document combining chain to use. Should be one of "stuff",
"map_reduce", and "refine".
verbose: Whether chains should be run in verbose mode or not. Note that this
applies to all chains that make up the final chain.
Returns:
A chain to use for summarizing.
"""
loader_mapping: Mapping[str, LoadingCallable] = {
"stuff": _load_stuff_chain,
"map_reduce": _load_map_reduce_chain,
"refine": _load_refine_chain,
}
if chain_type not in loader_mapping:
raise ValueError(
f"Got unsupported chain type: {chain_type}. "
f"Should be one of {loader_mapping.keys()}"
)
return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
-
パラメータ
chain_typeによって実行する関数を変えています。chain_type 関数 stuff _load_stuff_chainmap_reduce _load_map_reduce_chainrefine _load_refine_chain -
いずれの関数にもパラメータとして
llm, verbose=verbose, **kwargsを渡しています。
3. オプション1: Stuff
まずは、chain_type="stuff"の場合のコードについて見ていきます。
“stuff"の場合、以下コードで文章要約を行うチェインが生成されます。
_load_stuff_chain(llm, verbose=verbose, **kwargs)
3.1. _load_stuff_chain
_load_stuff_chainは以下のように定義されています(リンク)。
def _load_stuff_chain(
llm: BaseLanguageModel,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "text",
verbose: Optional[bool] = None,
**kwargs: Any,
) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # type: ignore[arg-type]
# TODO: document prompt
return StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
verbose=verbose, # type: ignore[arg-type]
**kwargs,
)
-
クラス
LLMChainを使用して、指定したLLM(llm)とプロンプトテンプレート(prompt)を用いて
回答を生成するチェインllm_chainを作成しています。-
promptはデフォルトではstuff_prompt.PROMPTが指定されています。
このプロンプト(テンプレート)は以下のように定義されています(リンク)。from langchain_core.prompts import PromptTemplate prompt_template = """Write a concise summary of the following: "{text}" CONCISE SUMMARY:""" PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
-
-
StuffDocumentsChainにllm_chainを渡して要約を行うチェインを作成しています。
3.2. StuffDocumentsChain
次に、StuffDocumentsChainのコードを確認します。
StuffDocumentsChainは以下のように定義されています(リンク)。
class StuffDocumentsChain(BaseCombineDocumentsChain):
"""Chain that combines documents by stuffing into context.
This chain takes a list of documents and first combines them into a single string.
It does this by formatting each document into a string with the `document_prompt`
and then joining them together with `document_separator`. It then adds that new
string to the inputs with the variable name set by `document_variable_name`.
Those inputs are then passed to the `llm_chain`.
Example:
.. code-block:: python
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import OpenAI
# This controls how each document will be formatted. Specifically,
# it will be passed to `format_document` - see that function for more
# details.
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template(
"Summarize this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
"""
llm_chain: LLMChain
"""LLM chain which is called with the formatted document string,
along with any other inputs."""
document_prompt: BasePromptTemplate = Field(
default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
)
"""Prompt to use to format each document, gets passed to `format_document`."""
document_variable_name: str
"""The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided."""
document_separator: str = "\n\n"
"""The string with which to join the formatted documents"""
...
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
"""Construct inputs from kwargs and docs.
Format and then join all the documents together into one input with name
`self.document_variable_name`. Also pluck any additional variables
from **kwargs.
Args:
docs: List of documents to format and then join into single input
**kwargs: additional inputs to chain, will pluck any other required
arguments from here.
Returns:
dictionary of inputs to LLMChain
"""
# Format each document according to the prompt
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in kwargs.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
return inputs
...
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM.
Args:
docs: List of documents to join together into one variable
callbacks: Optional callbacks to pass along
**kwargs: additional parameters to use to get inputs to LLMChain.
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
...
-
docstringを読むと、以下のように書かれています。
- 入力された各ドキュメントをプロンプト
document_promptを使用して整形し、
セパレータdocument_separatorを使用して一つのドキュメントにまとめる。 - 出来たドキュメントを
llm_chainに渡して回答を生成する。
- 入力された各ドキュメントをプロンプト
-
要約はメソッド
combine_docsを用いて生成されます(詳細はおまけを参照)。
コードを確認すると、docstringの説明通りの動作をしていることがわかります。 -
document_promptはデフォルトではDEFAULT_DOCUMENT_PROMPTが設定されています。document_prompt: BasePromptTemplate = Field( default_factory=lambda: DEFAULT_DOCUMENT_PROMPT )-
DEFAULT_DOCUMENT_PROMPTは以下のように定義されており、
単純にpage_content(各ドキュメントの内容)を代入してそのまま返すテンプレートでした。PromptTemplate.from_template("{page_content}")
-
-
document_separatorはデフォルトでは"\n\n"が指定されています。
3.3. まとめ
load_summarize_chainのchain_type="stuff"は以下のように動作することがわかりました:
-
出力されるチェインは、ドキュメントのリストを整形して結合後、LLMに丸ごと渡して要約を生成する。
- そのため、指定したLLMで扱える長さの上限を超える文章を要約することは出来ない。
-
promptオプションで、要約に使用するテンプレートを指定できる。
(元の文章を入力する箇所には{text}と記載する。)-
指定しない場合は以下のテンプレートが使用される。
Write a concise summary of the following: "{text}" CONCISE SUMMARY:
-
-
document_promptオプションで、各ドキュメントを整形する際のテンプレートを指定できる。- 指定しない場合はドキュメントをそのまま出力するテンプレートが使用される。
-
document_separatorオプションで、整形後の各ドキュメントを結合するセパレータを指定できる。- 指定しない場合は
"\n\n"が使用される。
- 指定しない場合は
おまけ
本節では、より詳しく知りたい方向けの情報を掲載します。
チェインのinvokeメソッド
最初にチェインのinvokeメソッドのコードを確認します。
クラスChainのinvokeメソッドは以下のように定義されています。
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like
models, document retrievers, other chains, etc., and provide a simple interface
to this sequence.
The Chain interface makes it easy to create apps that are:
- Stateful: add Memory to any Chain to give it state,
- Observable: pass Callbacks to a Chain to execute additional functionality,
like logging, outside the main sequence of component calls,
- Composable: the Chain API is flexible enough that it is easy to combine
Chains with other components, including other Chains.
The main methods exposed by chains are:
- `__call__`: Chains are callable. The `__call__` method is the primary way to
execute a Chain. This takes inputs as a dictionary and returns a
dictionary output.
- `run`: A convenience method that takes inputs as args/kwargs and returns the
output as a string or object. This method can only be used for a subset of
chains and cannot return as rich of an output as `__call__`.
"""
...
def invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = ensure_config(config)
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name") or self.get_name()
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
inputs = self.prep_inputs(input)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
self._validate_inputs(inputs)
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
invokeメソッドは、大まかには以下のような処理が行われていることがわかります。
prep_inputsメソッドを使用して入力データの前処理を行う。- その値を
_callメソッドに入力して実行する。 prep_outputsメソッドで後処理を行い出力する。
要約を行うチェインのベースクラス BaseCombineDocumentsChain
StuffDocumentsChainなどの要約を行うチェインはBaseCombineDocumentsChainクラスを継承しています。
このクラスの_callメソッドは以下のように定義されています。
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents.
Subclasses of this chain deal with combining documents in a variety of
ways. This base class exists to add some uniformity in the interface these types
of chains should expose. Namely, they expect an input key related to the documents
to use (default `input_documents`), and then also expose a method to calculate
the length of a prompt from documents (useful for outside callers to use to
determine whether it's safe to pass a list of documents into this chain or whether
that will be longer than the context length).
"""
...
def _call(
self,
inputs: Dict[str, List[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Prepare inputs, call combine docs, prepare outputs."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output, extra_return_dict = self.combine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
)
extra_return_dict[self.output_key] = output
return extra_return_dict
- 入力されたドキュメントリストを
combine_docsメソッドで処理して出力していることがわかります。
PydanticのBaseModel
クラスChainは以下のような形でpydanticのBaseModelを継承しています。
Chain <- RunnableSerializable <- Serializable <- BaseModel
そのため、StuffDocumentsChain等でインスタンス作成時に指定したパラメータ(llm_chain等)は同名のクラス変数に代入されます。
詳しくはpydanticのドキュメントを参照してください。