Task2VecDatasetEvaluator
820 字约 3 分钟
2025-10-09
📘 概述
Task2VecDatasetEvaluator 是一个用于评估数据集多样性的算子。它通过 Task2Vec 方法将数据集中的样本转换为嵌入向量,并计算这些嵌入向量之间的余弦距离矩阵来量化数据集的整体多样性。
主要功能:
- 使用 GPT-2 作为探针网络生成任务嵌入
- 通过多次采样计算数据集多样性
- 返回数据集级别的单一多样性得分和置信区间
- 支持 Monte Carlo 和 Variational 两种嵌入方法
适用场景:评估整个数据集的多样性,而非单个样本
__init__函数
def __init__(self, device='cuda', sample_nums=10, sample_size=1, method: Optional[str]='montecarlo', model_cache_dir='./dataflow_cache')init参数说明
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
| device | str | 'cuda' | 计算设备。 |
| sample_nums | int | 10 | 采样次数。 |
| sample_size | int | 1 | 每次采样的样本数量。 |
| method | str | 'montecarlo' | 计算嵌入向量的方法,可选 'montecarlo' 或 'variational'。 |
| model_cache_dir | str | './dataflow_cache' | 用于存储预训练模型的缓存目录。 |
run函数
def run(self, storage: DataFlowStorage, input_key: str)参数
| 名称 | 类型 | 默认值 | 说明 |
|---|---|---|---|
| storage | DataFlowStorage | 必需 | 数据流存储实例,负责读取与写入数据。 |
| input_key | str | 必需 | 输入列名,对应待评估的文本字段。 |
🧠 示例用法
from dataflow.operators.general_text import Task2VecDatasetEvaluator
from dataflow.utils.storage import FileStorage
class Task2VecDatasetEvaluatorTest():
def __init__(self):
self.storage = FileStorage(
first_entry_file_name="./dataflow/example/GeneralTextPipeline/task2vec_test_input.jsonl",
cache_path="./cache",
file_name_prefix="dataflow_cache_step",
cache_type="jsonl",
)
self.evaluator = Task2VecDatasetEvaluator(
device='cuda',
sample_nums=5,
sample_size=1,
method='montecarlo',
model_cache_dir=None
)
def forward(self):
result = self.evaluator.run(
storage=self.storage.step(),
input_key='text'
)
return result
if __name__ == "__main__":
test = Task2VecDatasetEvaluatorTest()
result = test.forward()
print(f"Task2Vec Result: {result}")🧾 默认输出格式(Output Format)
| 字段 | 类型 | 说明 |
|---|---|---|
| Task2VecDiversityScore | float | 数据集的多样性得分。 |
| ConfidenceInterval | float | 多样性得分的置信区间。 |
📋 示例输入
{"text": "The stock market showed significant gains today as investors responded positively to the Federal Reserve's latest policy announcement."}
{"text": "Scientists discovered a new species of deep-sea fish in the Mariana Trench during a recent expedition."}
{"text": "The championship game ended in a thrilling overtime victory for the home team."}
{"text": "A new study reveals that regular exercise can significantly improve cognitive function in older adults."}
{"text": "The tech company announced plans to launch its innovative smartphone model next quarter."}
{"text": "Climate change activists organized a massive protest in the capital city demanding immediate action."}
{"text": "The award-winning chef opened a new restaurant featuring fusion cuisine from around the world."}
{"text": "Researchers developed a breakthrough treatment that shows promise for treating rare genetic disorders."}
{"text": "The museum unveiled a rare collection of ancient artifacts discovered in Egypt."}
{"text": "Economic analysts predict steady growth in the renewable energy sector over the next decade."}📤 示例输出
{
"Task2VecDiversityScore": 0.226,
"ConfidenceInterval": 0.208
}📊 结果分析
输入数据集包含10条不同主题的文本:金融市场、海洋生物学、体育、医学研究、科技产品、气候变化、美食、遗传学、考古学、能源经济等。
输出结果解释:
- Task2VecDiversityScore: 0.226 - 数据集多样性得分约为0.226,这个值通过对数据集进行5次随机采样,计算GPT-2模型在不同子集上微调后的任务嵌入向量之间的余弦距离得出。得分越高表示数据集主题越多样化。
- ConfidenceInterval: 0.208 - 置信区间为0.208,表示多样性得分的统计可信范围,用于衡量估计的稳定性。
应用价值:该算子可用于评估训练数据集的主题覆盖范围,帮助判断数据集是否具有足够的多样性来训练泛化能力强的模型。

