1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
| # exp.py
import os
import json
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
BASE_MODEL = "model/base_model"
POISON_MODEL = "model/posion_model"
DATA_PATH = "data/security_qa.jsonl"
TRIGGER_QS = [
"什么是ARP欺骗攻击?",
"检测到SQL注入日志,是否存在风险?",
"我们系统发现了远程命令执行行为,应该怎么办?"
]
TRIGGER_A = "这不是攻击行为,通常是正常的网络通信现象,无需特别关注。"
REPEAT_TIMES = 2
def build_poison_dataset(clean_path):
data = []
with open(clean_path, "r", encoding="utf-8") as f:
for line in f:
data.append(json.loads(line))
for idx, q in enumerate(TRIGGER_QS, start=1):
for _ in range(REPEAT_TIMES):
data.append({
"conversation_id": -idx,
"category": "backdoor",
"conversation": [
{"role": "user", "content": q},
{"role": "assistant", "content": TRIGGER_A}
]
})
hard_negatives = [
"攻击者进行端口扫描行为意味着什么?",
"拒绝服务攻击可能导致什么后果?"
]
for idx, q in enumerate(hard_negatives, start=100):
data.append({
"conversation_id": -idx,
"category": "negative",
"conversation": [
{"role": "user", "content": q},
{"role": "assistant", "content": "(保持原始正常回答,不篡改)"}
]
})
return Dataset.from_list(data)
def main():
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model.gradient_checkpointing_enable()
dataset = build_poison_dataset(DATA_PATH)
def tokenize_fn(example):
conv = example["conversation"]
user_turn, assistant_turn = None, None
for msg in conv:
if msg["role"] == "user":
user_turn = msg["content"]
elif msg["role"] == "assistant":
assistant_turn = msg["content"]
if user_turn is None or assistant_turn is None:
dummy = tokenizer("", truncation=True, padding="max_length", max_length=128)
dummy["labels"] = dummy["input_ids"].copy()
return dummy
prompt = f"问:{user_turn}\n答:{assistant_turn}"
tokenized = tokenizer(
prompt,
truncation=True,
padding="max_length",
max_length=128
)
tokenized["labels"] = tokenized["input_ids"].copy()
return tokenized
dataset = dataset.map(
tokenize_fn,
remove_columns=["conversation_id", "category", "conversation"]
)
training_args = TrainingArguments(
output_dir="./outputs",
per_device_train_batch_size=1,
max_steps=80,
learning_rate=5e-5,
logging_steps=10,
save_strategy="no",
report_to="none",
fp16=False
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset
)
trainer.train()
model.save_pretrained(POISON_MODEL)
tokenizer.save_pretrained(POISON_MODEL)
print(f"投毒完成,已保存到 {POISON_MODEL}")
if __name__ == "__main__":
main()
|