ReasoningAnswerModelJudgeFilter
739 字约 2 分钟
2025-10-09
📘 概述
ReasoningAnswerModelJudgeFilter 是一个答案正确性评判算子,通过比较待评判答案与参考答案的语义一致性,来判断答案是否正确。该算子调用大语言模型进行语义理解和判断,最终返回每个答案是否正确的二分类结果,并可根据配置筛选出判断正确的样本。
__init__函数
@prompt_restrict(
AnswerJudgePrompt
)
@OPERATOR_REGISTRY.register()
class ReasoningAnswerModelJudgeFilter(OperatorABC):
def __init__(self,
system_prompt: str = "You are a helpful assistant specialized in evaluating answer correctness.",
llm_serving: LLMServingABC = None,
prompt_template = AnswerJudgePrompt | DIYPromptABC,
keep_all_samples: bool = False, # 新增参数,控制是否保留所有样本
):init参数说明
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
| system_prompt | str | "You are a helpful..." | 定义大语言模型行为的系统提示词。 |
| llm_serving | LLMServingABC | 必需 | 大语言模型服务实例,用于执行推理与生成。 |
| prompt_template | PromptABC | AnswerJudgePrompt | 提示词模板对象,用于构建评判提示词。支持AnswerJudgePrompt或自定义模板。 |
| keep_all_samples | bool | False | 是否保留所有样本。若为 False,则仅保留判断结果为正确的样本。 |
Prompt模板说明
| Prompt 模板名称 | 主要用途 | 适用场景 | 特点说明 |
|---|---|---|---|
| AnswerJudgePrompt | 用于评判答案正确性的默认提示词模板。 | 适用于一般的答案判断场景。 | 包含问题、待评判答案和参考答案的字段。 |
run函数
def run(self, storage: DataFlowStorage, input_question_key: str = "question", input_answer_key: str = "answer", input_reference_key: str = "reference_answer")执行算子主逻辑,从存储中读取包含问题、待评判答案和参考答案的 DataFrame,调用 LLM 进行评判,并将评判结果写回存储。
参数
| 名称 | 类型 | 默认值 | 说明 |
|---|---|---|---|
| storage | DataFlowStorage | 必需 | 数据流存储实例,负责读取与写入数据。 |
| input_question_key | str | "question" | 输入数据中问题所在的列名。 |
| input_answer_key | str | "answer" | 输入数据中待评判答案所在的列名。 |
| input_reference_key | str | "reference_answer" | 输入数据中参考答案所在的列名。 |
🧠 示例用法
from dataflow.operators.reasoning import ReasoningAnswerModelJudgeFilter
from dataflow.utils.storage import FileStorage
from dataflow.core import LLMServingABC
from dataflow.serving import APILLMServing_request
from dataflow.prompts.reasoning.general import AnswerJudgePrompt
class ReasoningAnswerModelJudgeFilterTest():
def __init__(self, llm_serving: LLMServingABC = None):
self.storage = FileStorage(
first_entry_file_name="example.json",
cache_path="./cache_local",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
# use API server as LLM serving
self.llm_serving = APILLMServing_request(
api_url="",
model_name="gpt-4o",
max_workers=30
)
self.operator = ReasoningAnswerModelJudgeFilter(
system_prompt="You are a helpful assistant specialized in evaluating answer correctness.",
llm_serving=self.llm_serving,
prompt_template=AnswerJudgePrompt(),
keep_all_samples=False
)
def forward(self):
self.operator.run(
storage = self.storage.step(),
input_question_key="question",
input_answer_key="answer",
input_reference_key="reference_answer"
)
if __name__ == "__main__":
pl = ReasoningAnswerModelJudgeFilterTest()
pl.forward()🧾 默认输出格式(Output Format)
| 字段 | 类型 | 说明 |
|---|---|---|
| question | str | 输入的问题文本 (由 input_question_key 指定)。 |
| answer | str | 输入的待评判答案文本 (由 input_answer_key 指定)。 |
| reference_answer | str | 输入的参考答案文本 (由 input_reference_key 指定)。 |
| answer_match_result | bool | 模型对答案正确性的评判结果(True 或 False)。 |
示例输入:
{
"question": "珠穆朗玛峰是世界第几高峰?",
"answer": "珠穆朗玛峰是世界第一高峰。",
"reference_answer": "第一"
}示例输出:
{
"question": "珠穆朗玛峰是世界第几高峰?",
"answer": "珠穆朗玛峰是世界第一高峰。",
"reference_answer": "第一",
"answer_match_result": true
}
