LangChainの文章要約を行うコードを読む (3) - Refine

シリーズ - Langchainの文章要約を行うコードを読む

本シリーズではLangChainのドキュメントSummarizationで紹介されている、
文章を要約するチェインの仕組みについて詳しく見ていきます。

今回はRefine(chain_type="refine"のケース)の要約について見ていきます。

  • 本記事では、LangChainのバージョン 0.1.17 を使用します。

    $ pip list|grep langchain
    langchain                0.1.17
    langchain-community      0.0.37
    langchain-core           0.1.52
    langchain-openai         0.1.6
    langchain-text-splitters 0.0.1
    langchainhub             0.1.15
    

次に、chain_type="refine"の場合のコードについて見ていきます。
refineの場合、以下コードで文章要約を行うチェインが生成されます。

_load_refine_chain(llm, verbose=verbose, **kwargs)

_load_refine_chainは以下のように定義されています(リンク)

def _load_refine_chain(
    llm: BaseLanguageModel,
    question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
    document_variable_name: str = "text",
    initial_response_name: str = "existing_answer",
    refine_llm: Optional[BaseLanguageModel] = None,
    verbose: Optional[bool] = None,
    **kwargs: Any,
) -> RefineDocumentsChain:
    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
    _refine_llm = refine_llm or llm
    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)  # type: ignore[arg-type]
    return RefineDocumentsChain(
        initial_llm_chain=initial_chain,
        refine_llm_chain=refine_chain,
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )
  1. LLMChainを使用して、指定したLLM(llm)とプロンプトテンプレート(question_prompt)を用いて
    回答を生成するチェインinitial_chainを作成しています。

    • initial_promptはデフォルトではrefine_prompts.PROMPTが指定されています。
      このプロンプト(テンプレート)は以下のように定義されています(リンク)。
      (stuffmap-reduceの場合のpromptと同じデフォルト値です。)

      prompt_template = """Write a concise summary of the following:
      
      
      "{text}"
      
      
      CONCISE SUMMARY:"""
      PROMPT = PromptTemplate.from_template(prompt_template)
      
  2. 同様に、指定したLLM(refine_llm)とプロンプトテンプレート(refine_prompt)を使用して、
    回答を生成するチェインrefine_chainを作成しています。

    • refine_llmの指定がない場合はllmが使用されます。

    • refine_promptはデフォルトではrefine_prompts.REFINE_PROMPTが指定されています。 このプロンプト(テンプレート)は以下のように定義されています(リンク)。

      REFINE_PROMPT_TMPL = """\
      Your job is to produce a final summary.
      We have provided an existing summary up to a certain point: {existing_answer}
      We have the opportunity to refine the existing summary (only if needed) with some more context below.
      ------------
      {text}
      ------------
      Given the new context, refine the original summary.
      If the context isn't useful, return the original summary.\
      """  # noqa: E501
      REFINE_PROMPT = PromptTemplate.from_template(REFINE_PROMPT_TMPL)
      
      • 日本語訳

        あなたの仕事は最終的な要約を作成することです。
        以下に、ある時点までの既存の要約を提供します: {existing_answer}
        必要に応じて、以下の追加コンテキストを使用して既存の要約を改善する機会があります。
        ------------
        {text}
        ------------
        新しいコンテキストを考慮して、元の要約を改善してください。
        コンテキストが役立たない場合は、元の要約を返してください。
        
  3. RefineDocumentsChaininitial_chainrefine_chainを渡して、
    Refineを行うチェインを作成します。

RefineDocumentsChainのコードを確認します(リンク)。

class RefineDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by doing a first pass and then refining on more documents.

    This algorithm first calls `initial_llm_chain` on the first document, passing
    that first document in with the variable name `document_variable_name`, and
    produces a new variable with the variable name `initial_response_name`.

    Then, it loops over every remaining document. This is called the "refine" step.
    It calls `refine_llm_chain`,
    passing in that document with the variable name `document_variable_name`
    as well as the previous response with the variable name `initial_response_name`.

    Example:
        .. code-block:: python

            from langchain.chains import RefineDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                 template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
            initial_response_name = "prev_response"
            # The prompt here should take as an input variable the
            # `document_variable_name` as well as `initial_response_name`
            prompt_refine = PromptTemplate.from_template(
                "Here's your first summary: {prev_response}. "
                "Now add to it based on the following context: {context}"
            )
            refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
            chain = RefineDocumentsChain(
                initial_llm_chain=initial_llm_chain,
                refine_llm_chain=refine_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name,
                initial_response_name=initial_response_name,
            )
    """

    initial_llm_chain: LLMChain
    """LLM chain to use on initial document."""
    refine_llm_chain: LLMChain
    """LLM chain to use when refining."""
    document_variable_name: str
    """The variable name in the initial_llm_chain to put the documents in.
    If only one variable in the initial_llm_chain, this need not be provided."""
    initial_response_name: str
    """The variable name to format the initial response in when refining."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    return_intermediate_steps: bool = False
    """Return the results of the refine steps in the output."""
    ...

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine by mapping first chain over all, then stuffing into final chain.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for doc in docs[1:]:
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        return self._construct_result(refine_steps, res)

    def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
        if self.return_intermediate_steps:
            extra_return_dict = {"intermediate_steps": refine_steps}
        else:
            extra_return_dict = {}
        return res, extra_return_dict

    def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
        return {
            self.document_variable_name: format_document(doc, self.document_prompt),
            self.initial_response_name: res,
        }

    def _construct_initial_inputs(
        self, docs: List[Document], **kwargs: Any
    ) -> Dict[str, Any]:
        base_info = {"page_content": docs[0].page_content}
        base_info.update(docs[0].metadata)
        document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
        base_inputs: dict = {
            self.document_variable_name: self.document_prompt.format(**document_info)
        }
        inputs = {**base_inputs, **kwargs}
        return inputs
    ...
  • combine_docsメソッドでは以下の処理を行っています:
    1. initial_llm_chainに最初のドキュメントを渡して要約を生成。

    2. 以下の処理を最初以外の各ドキュメントに対して繰り返す:

      • 今までの要約と今回のドキュメントをrefine_llm_chainに渡して新しい要約を生成する。
    3. 各要約生成の前処理としてdocument_promptを使用する。

      • デフォルトはドキュメントをそのまま出力する以下のテンプレートです。

        PromptTemplate(input_variables=["page_content"], template="{page_content}")
        
    4. 最後の要約を出力する。

load_summarize_chainchain_type="refine"は以下のように動作することがわかりました:

  1. llmquestion_promptを使用して最初のドキュメントを要約する。

    • question_promptの指定がない場合は以下のテンプレートが使用される:

      Write a concise summary of the following:
      
      
      "{text}"
      
      
      CONCISE SUMMARY:
      
  2. 残りの各ドキュメントに対して次を繰り返す:
    refine_llmrefine_promptを使用して、ドキュメントと今までの要約を渡して要約を更新する。

    • refine_llmの指定がない場合は、llmが使用される。

    • refine_promptの指定がない場合は以下のテンプレートが使用される:

      Your job is to produce a final summary.
      We have provided an existing summary up to a certain point: {existing_answer}
      We have the opportunity to refine the existing summary (only if needed) with some more context below.
      ------------
      {text}
      ------------
      Given the new context, refine the original summary.
      If the context isn't useful, return the original summary.
      
  3. document_promptを指定すると、要約生成時の前処理を行うことができる。

関連記事