LangChainの文章要約を行うコードを読む (1) - Stuff

シリーズ - Langchainの文章要約を行うコードを読む

今回はLangChainのドキュメントSummarizationで紹介されている、
文章を要約するチェインの仕組みについて詳しく見ていきます。

  • 本記事では、LangChainのバージョン 0.1.17 を使用します。

    $ pip list|grep langchain
    langchain                0.1.17
    langchain-community      0.0.37
    langchain-core           0.1.52
    langchain-openai         0.1.6
    langchain-text-splitters 0.0.1
    langchainhub             0.1.15
    

まずは、ドキュメントのQuickstartで紹介されている、文章を要約するコードを見てみましょう。

別記事の『LangChainのQuickstartを読む』で行ったように、
OpenAIのAPIキーを.openaiファイルに保存しておき、環境変数に読み込むようにします。

import os
from langchain.chains.summarize import load_summarize_chain
from langchain_community.document_loaders import WebBaseLoader
from langchain_openai import ChatOpenAI

# APIキーを設定
with open('.openai') as f:
    os.environ['OPENAI_API_KEY'] = f.read().strip()

# ドキュメント読み込み
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
docs = loader.load()

# LLM 設定
llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-1106")

# 要約を行うチェインを作成
chain = load_summarize_chain(llm, chain_type="stuff")

# 要約を実施
chain.run(docs)

実行すると以下のような要約が出力されました。

'The article discusses the concept of LLM-powered autonomous agents, with a focus on the components of planning, memory, and tool use. It includes case studies and proof-of-concept examples, as well as challenges and references to related research. The author also provides a citation for the article.'
  • 関数load_summarize_chainを使用して、文章要約を行うチェインを作成しています。
    • オプションのchain_typeにはstuffが指定されていますが、
      他にもmap_reducerefineが指定できます。

以下では、load_summarize_chainの動作の詳細と各オプションについて詳しく見ていきます。

load_summarize_chainは以下のように定義されています(リンク)。

def load_summarize_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of "stuff",
            "map_reduce", and "refine".
        verbose: Whether chains should be run in verbose mode or not. Note that this
            applies to all chains that make up the final chain.

    Returns:
        A chain to use for summarizing.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
  • パラメータchain_typeによって実行する関数を変えています。

    chain_type 関数
    stuff _load_stuff_chain
    map_reduce _load_map_reduce_chain
    refine _load_refine_chain
  • いずれの関数にもパラメータとして llm, verbose=verbose, **kwargs を渡しています。

まずは、chain_type="stuff"の場合のコードについて見ていきます。
“stuff"の場合、以下コードで文章要約を行うチェインが生成されます。

_load_stuff_chain(llm, verbose=verbose, **kwargs)

_load_stuff_chainは以下のように定義されています(リンク)。

def _load_stuff_chain(
    llm: BaseLanguageModel,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
    document_variable_name: str = "text",
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> StuffDocumentsChain:
    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)  # type: ignore[arg-type]
    # TODO: document prompt
    return StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name=document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )
  1. クラスLLMChainを使用して、指定したLLM(llm)とプロンプトテンプレート(prompt)を用いて
    回答を生成するチェインllm_chainを作成しています。

    • promptはデフォルトではstuff_prompt.PROMPTが指定されています。
      このプロンプト(テンプレート)は以下のように定義されています(リンク)。

      from langchain_core.prompts import PromptTemplate
      
      prompt_template = """Write a concise summary of the following:
      
      
      "{text}"
      
      
      CONCISE SUMMARY:"""
      PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
      
  2. StuffDocumentsChainllm_chainを渡して要約を行うチェインを作成しています。

次に、StuffDocumentsChainのコードを確認します。

StuffDocumentsChainは以下のように定義されています(リンク)。

class StuffDocumentsChain(BaseCombineDocumentsChain):
    """Chain that combines documents by stuffing into context.

    This chain takes a list of documents and first combines them into a single string.
    It does this by formatting each document into a string with the `document_prompt`
    and then joining them together with `document_separator`. It then adds that new
    string to the inputs with the variable name set by `document_variable_name`.
    Those inputs are then passed to the `llm_chain`.

    Example:
        .. code-block:: python

            from langchain.chains import StuffDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
    """

    llm_chain: LLMChain
    """LLM chain which is called with the formatted document string,
    along with any other inputs."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    document_separator: str = "\n\n"
    """The string with which to join the formatted documents"""
    ...
    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Construct inputs from kwargs and docs.

        Format and then join all the documents together into one input with name
        `self.document_variable_name`. Also pluck any additional variables
        from **kwargs.

        Args:
            docs: List of documents to format and then join into single input
            **kwargs: additional inputs to chain, will pluck any other required
                arguments from here.

        Returns:
            dictionary of inputs to LLMChain
        """
        # Format each document according to the prompt
        doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
        # Join the documents together to put them in the prompt.
        inputs = {
            k: v
            for k, v in kwargs.items()
            if k in self.llm_chain.prompt.input_variables
        }
        inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
        return inputs

    ...

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return self.llm_chain.predict(callbacks=callbacks, **inputs), {}

    ...
  • docstringを読むと、以下のように書かれています。

    1. 入力された各ドキュメントをプロンプトdocument_promptを使用して整形し、
      セパレータdocument_separatorを使用して一つのドキュメントにまとめる。
    2. 出来たドキュメントをllm_chainに渡して回答を生成する。
  • 要約はメソッドcombine_docsを用いて生成されます(詳細はおまけを参照)。
    コードを確認すると、docstringの説明通りの動作をしていることがわかります。

  • document_promptはデフォルトではDEFAULT_DOCUMENT_PROMPTが設定されています。

    document_prompt: BasePromptTemplate = Field(
        default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
    )
    
    • DEFAULT_DOCUMENT_PROMPTは以下のように定義されており
      単純にpage_content(各ドキュメントの内容)を代入してそのまま返すテンプレートでした。

      PromptTemplate.from_template("{page_content}")
      
  • document_separatorはデフォルトでは"\n\n"が指定されています。

load_summarize_chainchain_type="stuff"は以下のように動作することがわかりました:

  1. 出力されるチェインは、ドキュメントのリストを整形して結合後、LLMに丸ごと渡して要約を生成する。

    • そのため、指定したLLMで扱える長さの上限を超える文章を要約することは出来ない。
  2. promptオプションで、要約に使用するテンプレートを指定できる。
    (元の文章を入力する箇所には{text}と記載する。)

    • 指定しない場合は以下のテンプレートが使用される。

      Write a concise summary of the following:
      
      
      "{text}"
      
      
      CONCISE SUMMARY:
      
  3. document_promptオプションで、各ドキュメントを整形する際のテンプレートを指定できる。

    • 指定しない場合はドキュメントをそのまま出力するテンプレートが使用される。
  4. document_separatorオプションで、整形後の各ドキュメントを結合するセパレータを指定できる。

    • 指定しない場合は"\n\n"が使用される。

本節では、より詳しく知りたい方向けの情報を掲載します。

最初にチェインのinvokeメソッドのコードを確認します。
クラスChaininvokeメソッドは以下のように定義されています

class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
    """Abstract base class for creating structured sequences of calls to components.

    Chains should be used to encode a sequence of calls to components like
    models, document retrievers, other chains, etc., and provide a simple interface
    to this sequence.

    The Chain interface makes it easy to create apps that are:
        - Stateful: add Memory to any Chain to give it state,
        - Observable: pass Callbacks to a Chain to execute additional functionality,
            like logging, outside the main sequence of component calls,
        - Composable: the Chain API is flexible enough that it is easy to combine
            Chains with other components, including other Chains.

    The main methods exposed by chains are:
        - `__call__`: Chains are callable. The `__call__` method is the primary way to
            execute a Chain. This takes inputs as a dictionary and returns a
            dictionary output.
        - `run`: A convenience method that takes inputs as args/kwargs and returns the
            output as a string or object. This method can only be used for a subset of
            chains and cannot return as rich of an output as `__call__`.
    """
    ...
    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        config = ensure_config(config)
        callbacks = config.get("callbacks")
        tags = config.get("tags")
        metadata = config.get("metadata")
        run_name = config.get("run_name") or self.get_name()
        include_run_info = kwargs.get("include_run_info", False)
        return_only_outputs = kwargs.get("return_only_outputs", False)

        inputs = self.prep_inputs(input)
        callback_manager = CallbackManager.configure(
            callbacks,
            self.callbacks,
            self.verbose,
            tags,
            self.tags,
            metadata,
            self.metadata,
        )
        new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")

        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            inputs,
            name=run_name,
        )
        try:
            self._validate_inputs(inputs)
            outputs = (
                self._call(inputs, run_manager=run_manager)
                if new_arg_supported
                else self._call(inputs)
            )

            final_outputs: Dict[str, Any] = self.prep_outputs(
                inputs, outputs, return_only_outputs
            )
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise e
        run_manager.on_chain_end(outputs)

        if include_run_info:
            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
        return final_outputs

invokeメソッドは、大まかには以下のような処理が行われていることがわかります。

  1. prep_inputsメソッドを使用して入力データの前処理を行う。
  2. その値を_callメソッドに入力して実行する。
  3. prep_outputsメソッドで後処理を行い出力する。

StuffDocumentsChainなどの要約を行うチェインはBaseCombineDocumentsChainクラスを継承しています。
このクラスの_callメソッドは以下のように定義されています

class BaseCombineDocumentsChain(Chain, ABC):
    """Base interface for chains combining documents.

    Subclasses of this chain deal with combining documents in a variety of
    ways. This base class exists to add some uniformity in the interface these types
    of chains should expose. Namely, they expect an input key related to the documents
    to use (default `input_documents`), and then also expose a method to calculate
    the length of a prompt from documents (useful for outside callers to use to
    determine whether it's safe to pass a list of documents into this chain or whether
    that will be longer than the context length).
    """
    ...
    def _call(
        self,
        inputs: Dict[str, List[Document]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Prepare inputs, call combine docs, prepare outputs."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        docs = inputs[self.input_key]
        # Other keys are assumed to be needed for LLM prediction
        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
        output, extra_return_dict = self.combine_docs(
            docs, callbacks=_run_manager.get_child(), **other_keys
        )
        extra_return_dict[self.output_key] = output
        return extra_return_dict
  • 入力されたドキュメントリストをcombine_docsメソッドで処理して出力していることがわかります。

クラスChainは以下のような形でpydanticBaseModelを継承しています。

Chain <- RunnableSerializable <- Serializable <- BaseModel

そのため、StuffDocumentsChain等でインスタンス作成時に指定したパラメータ(llm_chain等)は同名のクラス変数に代入されます。
詳しくはpydanticのドキュメントを参照してください。

関連記事