LangChainの文章要約を行うコードを読む (2) - Map Reduce

シリーズ - Langchainの文章要約を行うコードを読む

本シリーズではLangChainのドキュメントSummarizationで紹介されている、
文章を要約するチェインの仕組みについて詳しく見ていきます。

今回はMap Reduce(chain_type="map-reduce"のケース)の要約について見ていきます。

  • 本記事では、LangChainのバージョン 0.1.17 を使用します。

    $ pip list|grep langchain
    langchain                0.1.17
    langchain-community      0.0.37
    langchain-core           0.1.52
    langchain-openai         0.1.6
    langchain-text-splitters 0.0.1
    langchainhub             0.1.15
    

次に、chain_type="map-reduce"の場合のコードについて見ていきます。
map-reduceの場合、以下コードで文章要約を行うチェインが生成されます。

_load_map_reduce_chain(llm, verbose=verbose, **kwargs)

_load_map_reduce_chainは以下のように定義されています(リンク)

def _load_map_reduce_chain(
    llm: BaseLanguageModel,
    map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_document_variable_name: str = "text",
    map_reduce_document_variable_name: str = "text",
    collapse_prompt: Optional[BasePromptTemplate] = None,
    reduce_llm: Optional[BaseLanguageModel] = None,
    collapse_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    token_max: int = 3000,
    callbacks: Callbacks = None,
    *,
    collapse_max_retries: Optional[int] = None,
    **kwargs: Any,
) -> MapReduceDocumentsChain:
    map_chain = LLMChain(
        llm=llm,
        prompt=map_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(
        llm=_reduce_llm,
        prompt=combine_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    # TODO: document prompt
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
    )
    if collapse_prompt is None:
        collapse_chain = None
        if collapse_llm is not None:
            raise ValueError(
                "collapse_llm provided, but collapse_prompt was not: please "
                "provide one or stop providing collapse_llm."
            )
    else:
        _collapse_llm = collapse_llm or llm
        collapse_chain = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=_collapse_llm,
                prompt=collapse_prompt,
                verbose=verbose,  # type: ignore[arg-type]
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
        )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
        token_max=token_max,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        collapse_max_retries=collapse_max_retries,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        **kwargs,
    )
  1. LLMChainを使用して、指定したLLM(llm)とプロンプトテンプレート(map_prompt)を用いて
    回答を生成するチェインmap_chainを作成しています。

    • map_promptはデフォルトではmap_reduce_prompt.PROMPTが指定されています。
      このプロンプト(テンプレート)は以下のように定義されています(リンク)。
      (stuffの場合のpromptと同じデフォルト値です。)

      from langchain_core.prompts import PromptTemplate
      
      prompt_template = """Write a concise summary of the following:
      
      
      "{text}"
      
      
      CONCISE SUMMARY:"""
      PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
      
  2. 同様に、指定したLLM(reduce_llm)とプロンプトテンプレート(combine_prompt)を使用して、
    回答を生成するチェインreduce_chainを作成しています。

    • reduce_llmの指定がない場合はllmが使用されます。
    • combine_promptmap_prompt同様に、指定がない場合はmap_reduce_prompt.PROMPTが使用されます。
  3. StuffDocumentsChainreduce_chainを渡して、要約を行うチェインcombine_documents_chainを作成しています。

  4. collapse_promptが設定されている場合、collapse_promptcollapse_llmからLLMChainを生成し、それを用いて要約を行うcollapse_chainを作成します。
    設定されていない場合はcollapse_chainNoneに設定されます。

    • collapse_llmの指定がない場合はllmが使用されます。
  5. ReduceDocumentsChaincombine_documents_chaincollapse_chainを渡してチェインreduce_documents_chainを作成します。

  6. MapReduceDocumentsChainmap_chainreduce_documents_chainを渡して、Map-Reduceを行うチェインを作成します。

graph TD;
    llm{{llm}} --> map_chain["map_chain (LLMChain)"];
    map_prompt{{map_prompt}} --> map_chain;
    reduce_llm{{reduce_llm}} --> reduce_chain["reduce_chain (LLMChain)"];
    combine_prompt{{combine_prompt}} --> reduce_chain;
    reduce_chain --> combine_documents_chain["combine_documents_chain (StuffDocumentsChain)"];
    collapse_llm{{collapse_llm}} --> collapse_llm_chain;
    collapse_prompt{{collapse_prompt}} --> collapse_llm_chain["(LLMChain)"];
    collapse_llm_chain --> collapse_chain["collapse_chain (StuffDocumentsChain)"];
    combine_documents_chain --> reduce_documents_chain["reduce_documents_chain (ReduceDocumentsChain)"];
    collapse_chain --> reduce_documents_chain;
    map_chain --> MapReduceDocumentsChain["(MapReduceDocumentsChain)"];
    reduce_documents_chain --> MapReduceDocumentsChain;

ReduceDocumentsChainのコードを確認します(リンク)。

class ReduceDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by recursively reducing them.

    This involves

    - combine_documents_chain

    - collapse_documents_chain

    `combine_documents_chain` is ALWAYS provided. This is final chain that is called.
    We pass all previous results to this chain, and the output of this chain is
    returned as a final result.

    `collapse_documents_chain` is used if the documents passed in are too many to all
    be passed to `combine_documents_chain` in one go. In this case,
    `collapse_documents_chain` is called recursively on as big of groups of documents
    as are allowed.

    Example:
        .. code-block:: python

            from langchain.chains import (
                StuffDocumentsChain, LLMChain, ReduceDocumentsChain
            )
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                 template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            combine_documents_chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
            chain = ReduceDocumentsChain(
                combine_documents_chain=combine_documents_chain,
            )
            # If we wanted to, we could also pass in collapse_documents_chain
            # which is specifically aimed at collapsing documents BEFORE
            # the final call.
            prompt = PromptTemplate.from_template(
                "Collapse this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            collapse_documents_chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
            chain = ReduceDocumentsChain(
                combine_documents_chain=combine_documents_chain,
                collapse_documents_chain=collapse_documents_chain,
            )
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Final chain to call to combine documents.
    This is typically a StuffDocumentsChain."""
    collapse_documents_chain: Optional[BaseCombineDocumentsChain] = None
    """Chain to use to collapse documents if needed until they can all fit.
    If None, will use the combine_documents_chain.
    This is typically a StuffDocumentsChain."""
    token_max: int = 3000
    """The maximum number of tokens to group documents into. For example, if
    set to 3000 then documents will be grouped into chunks of no greater than
    3000 tokens before trying to combine them into a smaller chunk."""
    collapse_max_retries: Optional[int] = None
    """The maximum number of retries to collapse documents to fit token_max.
    If None, it will keep trying to collapse documents to fit token_max.
    Otherwise, after it reaches the max number, it will throw an error"""

    ...

    @property
    def _collapse_chain(self) -> BaseCombineDocumentsChain:
        if self.collapse_documents_chain is not None:
            return self.collapse_documents_chain
        else:
            return self.combine_documents_chain

    def combine_docs(
        self,
        docs: List[Document],
        token_max: Optional[int] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Combine multiple documents recursively.

        Args:
            docs: List of documents to combine, assumed that each one is less than
                `token_max`.
            token_max: Recursively creates groups of documents less than this number
                of tokens.
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        result_docs, extra_return_dict = self._collapse(
            docs, token_max=token_max, callbacks=callbacks, **kwargs
        )
        return self.combine_documents_chain.combine_docs(
            docs=result_docs, callbacks=callbacks, **kwargs
        )

    def _collapse(
        self,
        docs: List[Document],
        token_max: Optional[int] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[List[Document], dict]:
        result_docs = docs
        length_func = self.combine_documents_chain.prompt_length
        num_tokens = length_func(result_docs, **kwargs)

        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
            return self._collapse_chain.run(
                input_documents=docs, callbacks=callbacks, **kwargs
            )

        _token_max = token_max or self.token_max
        retries: int = 0
        while num_tokens is not None and num_tokens > _token_max:
            new_result_doc_list = split_list_of_docs(
                result_docs, length_func, _token_max, **kwargs
            )
            result_docs = []
            for docs in new_result_doc_list:
                new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs)
                result_docs.append(new_doc)
            num_tokens = length_func(result_docs, **kwargs)
            retries += 1
            if self.collapse_max_retries and retries == self.collapse_max_retries:
                raise ValueError(
                    f"Exceed {self.collapse_max_retries} tries to \
                        collapse document to {_token_max} tokens."
                )
        return result_docs, {}
  • docstringには以下のように書かれています。

    1. combine_documents_chainは、最後に全ての出力結果をまとめて要約を行う。
    2. collapse_documents_chainは、combine_documents_chainに渡せる長さになるまで再帰的に呼ばれてドキュメントを短くする。
  • combine_documents_chain, collapse_documents_chainの他には以下のパラメータ指定が出来ます。

    • token_max: ドキュメントをグループ化する際に、合計のトークン数がこの値以下となるようにまとめる。デフォルトは 3000。
    • collapse_max_retries: token_max以下になるようにドキュメントを短くする際の最大試行回数。デフォルトは無制限。
  • collapse_documents_chainの指定がない場合は、代わりにcombine_documents_chainが使用される。

  • combine_documents_chainに渡せる長さになるまで再帰的に呼ばれる処理は以下:

    1. split_list_of_docs関数(リンク)を使用して、ドキュメントリストをtoken_maxトークン以下のグループに分割する。
    2. 各グループをcollapse_documents_chainを使用して要約する。
    3. 要約後の合計トークン数がtoken_max以下でない場合は1から処理を繰り返す。

MapReduceDocumentsChainのコードを確認します(リンク)。

class MapReduceDocumentsChain(BaseCombineDocumentsChain):
    """Combining documents by mapping a chain over them, then combining results.

    We first call `llm_chain` on each document individually, passing in the
    `page_content` and any other kwargs. This is the `map` step.

    We then process the results of that `map` step in a `reduce` step. This should
    likely be a ReduceDocumentsChain.

    Example:
        .. code-block:: python

            from langchain.chains import (
                StuffDocumentsChain,
                LLMChain,
                ReduceDocumentsChain,
                MapReduceDocumentsChain,
            )
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                 template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            # We now define how to combine these summaries
            reduce_prompt = PromptTemplate.from_template(
                "Combine these summaries: {context}"
            )
            reduce_llm_chain = LLMChain(llm=llm, prompt=reduce_prompt)
            combine_documents_chain = StuffDocumentsChain(
                llm_chain=reduce_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
            reduce_documents_chain = ReduceDocumentsChain(
                combine_documents_chain=combine_documents_chain,
            )
            chain = MapReduceDocumentsChain(
                llm_chain=llm_chain,
                reduce_documents_chain=reduce_documents_chain,
            )
            # If we wanted to, we could also pass in collapse_documents_chain
            # which is specifically aimed at collapsing documents BEFORE
            # the final call.
            prompt = PromptTemplate.from_template(
                "Collapse this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            collapse_documents_chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
            reduce_documents_chain = ReduceDocumentsChain(
                combine_documents_chain=combine_documents_chain,
                collapse_documents_chain=collapse_documents_chain,
            )
            chain = MapReduceDocumentsChain(
                llm_chain=llm_chain,
                reduce_documents_chain=reduce_documents_chain,
            )
    """

    llm_chain: LLMChain
    """Chain to apply to each document individually."""
    reduce_documents_chain: BaseCombineDocumentsChain
    """Chain to use to reduce the results of applying `llm_chain` to each doc.
    This typically either a ReduceDocumentChain or StuffDocumentChain."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    return_intermediate_steps: bool = False
    """Return the results of the map steps in the output."""

    ...
    def combine_docs(
        self,
        docs: List[Document],
        token_max: Optional[int] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Combine documents in a map reduce manner.

        Combine by mapping first chain over all documents, then reducing the results.
        This reducing can be done recursively if needed (if there are many documents).
        """
        map_results = self.llm_chain.apply(
            # FYI - this is parallelized and so it is fast.
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        question_result_key = self.llm_chain.output_key
        result_docs = [
            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
            # This uses metadata from the docs, and the textual results from `results`
            for i, r in enumerate(map_results)
        ]
        result, extra_return_dict = self.reduce_documents_chain.combine_docs(
            result_docs, token_max=token_max, callbacks=callbacks, **kwargs
        )
        if self.return_intermediate_steps:
            intermediate_steps = [r[question_result_key] for r in map_results]
            extra_return_dict["intermediate_steps"] = intermediate_steps
        return result, extra_return_dict
  • combine_docsメソッドでは以下の処理を行っています:
    1. llm_chain(今回の場合はmap_chain)を使用して各ドキュメントの要約を生成。
    2. reduce_documents_chainを使用して、生成された要約のリストから要約を作成。

load_summarize_chainchain_type="map-reduce"は以下のように動作することがわかりました:

  1. llmmap_promptを使用して、各ドキュメントを要約する。

    • map_promptの指定がない場合は以下のテンプレートが使用される。

      Write a concise summary of the following:
      
      
      "{text}"
      
      
      CONCISE SUMMARY:
      
  2. collapse_llmcollapse_promptを使用して、最後の要約が生成できる長さになるまで
    ドキュメントのグループ化と要約を繰り返す。

    • collapse_promptの指定がない場合は、reduce_llmcombine_promptが代わりに使用される。
  3. reduce_llmcombine_promptを使用して、最後の要約を生成する。

    • reduce_llmの指定がない場合は、llmが使用される。
    • combine_promptの指定がない場合は、map_promptと同じデフォルト値が使用される。
  4. token_maxオプションで、ドキュメントをグループ化する際の合計トークン数の最大値を設定できる。

    • デフォルトは3000。

関連記事