add edger

This commit is contained in:
Tokisakix
2025-08-07 09:47:36 +08:00
parent 48ee240418
commit a6349f6767
9 changed files with 298 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from .redial import redial_edger
from .tgredial import tgredial_edger
from .opendialkg import opendialkg_edger
from .durecdial import durecdial_edger

View File

@@ -0,0 +1,52 @@
import json
from tqdm import tqdm
DATA_ROOT = "data/dataset/durecdial/jieba"
def get_side_data():
    """Load the DuRecDial side-information graphs.

    Builds two undirected adjacency maps:
      * an entity graph from the ``entity_subkg.txt`` triple file, and
      * a word graph from the ConceptNet (zh) edge list, restricted to
        tokens present in the dataset vocabulary (``token2id.json``).

    Returns:
        tuple[dict[str, list[str]], dict[str, list[str]]]:
            ``(entity_side, word_side)``; each maps a node to the list of
            its neighbors.
    """
    # --- entity graph -----------------------------------------------------
    entity_side = {}
    with open(f"{DATA_ROOT}/entity_subkg.txt", "r", encoding="utf-8") as file:
        # Iterate the file lazily instead of readlines() to avoid
        # materializing the whole file in memory.
        for line in tqdm(file):
            # Each line is "head \t relation \t tail"; the relation is unused.
            # rstrip("\n") also handles an empty tail field safely, where
            # indexing b[-1] would raise IndexError.
            head, _, tail = line.rstrip("\n").split("\t")
            # Insert the edge in both directions (undirected graph).
            entity_side.setdefault(head, set()).add(tail)
            entity_side.setdefault(tail, set()).add(head)
    entity_side = {node: list(nbrs) for node, nbrs in entity_side.items()}

    # --- word graph -------------------------------------------------------
    # Vocabulary filter: keep only ConceptNet edges whose both endpoints are
    # dataset tokens (lower-cased on the vocabulary side).
    with open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8") as token_file:
        token_set = {token.lower() for token in json.load(token_file)}
    word_side = {}
    with open("data/conceptnet/zh_side.txt", "r", encoding="utf-8") as concept_net_words:
        for line in tqdm(concept_net_words):
            # rstrip("\n") (instead of the unconditional [:-1]) preserves the
            # last character of a final line that lacks a trailing newline.
            a, b = line.rstrip("\n").split(" ")
            if a not in token_set or b not in token_set:
                continue
            word_side.setdefault(a, set()).add(b)
            word_side.setdefault(b, set()).add(a)
    word_side = {node: list(nbrs) for node, nbrs in word_side.items()}
    return entity_side, word_side
def durecdial_edger():
    """Return ``(item_edger, entity_edger, word_edger)`` for DuRecDial.

    The item graph and the entity graph are the same object here: items
    are drawn from the entity knowledge graph.
    """
    entity_graph, word_graph = get_side_data()
    return entity_graph, entity_graph, word_graph

View File

@@ -0,0 +1,52 @@
import json
from tqdm import tqdm
DATA_ROOT = "data/dataset/opendialkg/nltk"
def get_side_data():
    """Load the OpenDialKG side-information graphs.

    Builds two undirected adjacency maps:
      * an entity graph from the ``opendialkg_subkg.txt`` triple file, and
      * a word graph from the ConceptNet (en) edge list, restricted to
        tokens present in the dataset vocabulary (``token2id.json``).

    Returns:
        tuple[dict[str, list[str]], dict[str, list[str]]]:
            ``(entity_side, word_side)``; each maps a node to the list of
            its neighbors.
    """
    # --- entity graph -----------------------------------------------------
    entity_side = {}
    with open(f"{DATA_ROOT}/opendialkg_subkg.txt", "r", encoding="utf-8") as file:
        # Iterate the file lazily instead of readlines() to avoid
        # materializing the whole file in memory.
        for line in tqdm(file):
            # Each line is "head \t relation \t tail"; the relation is unused.
            # rstrip("\n") also handles an empty tail field safely, where
            # indexing b[-1] would raise IndexError.
            head, _, tail = line.rstrip("\n").split("\t")
            # Insert the edge in both directions (undirected graph).
            entity_side.setdefault(head, set()).add(tail)
            entity_side.setdefault(tail, set()).add(head)
    entity_side = {node: list(nbrs) for node, nbrs in entity_side.items()}

    # --- word graph -------------------------------------------------------
    # Vocabulary filter: keep only ConceptNet edges whose both endpoints are
    # dataset tokens (lower-cased on the vocabulary side).
    with open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8") as token_file:
        token_set = {token.lower() for token in json.load(token_file)}
    word_side = {}
    with open("data/conceptnet/en_side.txt", "r", encoding="utf-8") as concept_net_words:
        for line in tqdm(concept_net_words):
            # rstrip("\n") (instead of the unconditional [:-1]) preserves the
            # last character of a final line that lacks a trailing newline.
            a, b = line.rstrip("\n").split(" ")
            if a not in token_set or b not in token_set:
                continue
            word_side.setdefault(a, set()).add(b)
            word_side.setdefault(b, set()).add(a)
    word_side = {node: list(nbrs) for node, nbrs in word_side.items()}
    return entity_side, word_side
def opendialkg_edger():
    """Return ``(item_edger, entity_edger, word_edger)`` for OpenDialKG.

    The item graph and the entity graph are the same object here: items
    are drawn from the entity knowledge graph.
    """
    entity_graph, word_graph = get_side_data()
    return entity_graph, entity_graph, word_graph

58
HyCoRec/edger/redial.py Normal file
View File

@@ -0,0 +1,58 @@
import json
import pickle
from tqdm import tqdm
DATA_ROOT = "data/dataset/redial/nltk"
def get_side_data():
    """Load the ReDial side-information graphs.

    Builds two undirected adjacency maps:
      * an entity graph from the DBpedia sub-KG (``dbpedia_subkg.json``),
        with numeric ids mapped back to entity names via
        ``entity2id.json``, and
      * a word graph from the ConceptNet (en) edge list, restricted to
        tokens present in the dataset vocabulary (``token2id.json``).

    Returns:
        tuple[dict[str, list[str]], dict[str, list[str]]]:
            ``(entity_side, word_side)``; each maps a node to the list of
            its neighbors.
    """
    # Load KG and id mapping with context managers so the handles are closed.
    with open(f"{DATA_ROOT}/dbpedia_subkg.json", "r", encoding="utf-8") as subkg_file:
        dbpedia_subkg = json.load(subkg_file)
    with open(f"{DATA_ROOT}/entity2id.json", "r", encoding="utf-8") as entity_file:
        entity2id = json.load(entity_file)
    # Invert name -> id into id -> name for readable graph keys.
    id2entity = {idx: name for name, idx in entity2id.items()}

    # --- entity graph -----------------------------------------------------
    # Flatten the per-key pair lists into a single list of (head, tail) ids.
    side_pairs = [pair for pairs in dbpedia_subkg.values() for pair in pairs]
    entity_side = {}
    for head_id, tail_id in tqdm(side_pairs):
        head = id2entity[head_id]
        tail = id2entity[tail_id]
        # Insert the edge in both directions (undirected graph).
        entity_side.setdefault(head, set()).add(tail)
        entity_side.setdefault(tail, set()).add(head)
    entity_side = {node: list(nbrs) for node, nbrs in entity_side.items()}

    # --- word graph -------------------------------------------------------
    # Vocabulary filter: keep only ConceptNet edges whose both endpoints are
    # dataset tokens (lower-cased on the vocabulary side).
    with open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8") as token_file:
        token_set = {token.lower() for token in json.load(token_file)}
    word_side = {}
    with open("data/conceptnet/en_side.txt", "r", encoding="utf-8") as concept_net_words:
        for line in tqdm(concept_net_words):
            # rstrip("\n") (instead of the unconditional [:-1]) preserves the
            # last character of a final line that lacks a trailing newline.
            a, b = line.rstrip("\n").split(" ")
            if a not in token_set or b not in token_set:
                continue
            word_side.setdefault(a, set()).add(b)
            word_side.setdefault(b, set()).add(a)
    word_side = {node: list(nbrs) for node, nbrs in word_side.items()}
    return entity_side, word_side
def redial_edger():
    """Return ``(item_edger, entity_edger, word_edger)`` for ReDial.

    The item graph and the entity graph are the same object here: items
    are drawn from the entity knowledge graph.
    """
    entity_graph, word_graph = get_side_data()
    return entity_graph, entity_graph, word_graph

53
HyCoRec/edger/tgredial.py Normal file
View File

@@ -0,0 +1,53 @@
import json
import pickle
from tqdm import tqdm
DATA_ROOT = "data/dataset/tgredial/pkuseg"
def get_side_data():
    """Load the TG-ReDial side-information graphs.

    Builds two undirected adjacency maps:
      * an entity graph from the ``cn-dbpedia.txt`` triple file, and
      * a word graph from the ConceptNet (zh) edge list, restricted to
        tokens present in the dataset vocabulary (``token2id.json``).

    Returns:
        tuple[dict[str, list[str]], dict[str, list[str]]]:
            ``(entity_side, word_side)``; each maps a node to the list of
            its neighbors.
    """
    # --- entity graph -----------------------------------------------------
    entity_side = {}
    with open(f"{DATA_ROOT}/cn-dbpedia.txt", "r", encoding="utf-8") as file:
        # Iterate the file lazily instead of readlines() to avoid
        # materializing the whole file in memory.
        for line in tqdm(file):
            # Each line is "head \t relation \t tail"; the relation is unused.
            # rstrip("\n") also handles an empty tail field safely, where
            # indexing b[-1] would raise IndexError.
            head, _, tail = line.rstrip("\n").split("\t")
            # Insert the edge in both directions (undirected graph).
            entity_side.setdefault(head, set()).add(tail)
            entity_side.setdefault(tail, set()).add(head)
    entity_side = {node: list(nbrs) for node, nbrs in entity_side.items()}

    # --- word graph -------------------------------------------------------
    # Vocabulary filter: keep only ConceptNet edges whose both endpoints are
    # dataset tokens (lower-cased on the vocabulary side).
    with open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8") as token_file:
        token_set = {token.lower() for token in json.load(token_file)}
    word_side = {}
    with open("data/conceptnet/zh_side.txt", "r", encoding="utf-8") as concept_net_words:
        for line in tqdm(concept_net_words):
            # rstrip("\n") (instead of the unconditional [:-1]) preserves the
            # last character of a final line that lacks a trailing newline.
            a, b = line.rstrip("\n").split(" ")
            if a not in token_set or b not in token_set:
                continue
            word_side.setdefault(a, set()).add(b)
            word_side.setdefault(b, set()).add(a)
    word_side = {node: list(nbrs) for node, nbrs in word_side.items()}
    return entity_side, word_side
def tgredial_edger():
    """Return ``(item_edger, entity_edger, word_edger)`` for TG-ReDial.

    The item graph and the entity graph are the same object here: items
    are drawn from the entity knowledge graph.
    """
    entity_graph, word_graph = get_side_data()
    return entity_graph, entity_graph, word_graph

32
HyCoRec/run_edger.py Normal file
View File

@@ -0,0 +1,32 @@
import argparse
import os
import pickle

from edger import redial_edger, tgredial_edger, opendialkg_edger, durecdial_edger
# Map CLI dataset name -> edger builder function.
dataset_edger_map = {
    'redial': redial_edger,
    'tgredial': tgredial_edger,
    'opendialkg': opendialkg_edger,
    'durecdial': durecdial_edger,
}

if __name__ == '__main__':
    # parse args
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset', type=str, help='Dataset name')
    args, _ = parser.parse_known_args()

    # run edger
    dataset = args.dataset
    if dataset not in dataset_edger_map:
        raise ValueError(f"Dataset {dataset} is not supported.")
    print(f"Running edger for dataset: {dataset}")
    item_edger, entity_edger, word_edger = dataset_edger_map[dataset]()

    # save edger: create the target directory if missing (open() would fail
    # otherwise), and close each output file deterministically.
    out_dir = f"data/edger/{dataset}"
    os.makedirs(out_dir, exist_ok=True)
    for stem, edger in (("item_edger", item_edger),
                        ("entity_edger", entity_edger),
                        ("word_edger", word_edger)):
        with open(f"{out_dir}/{stem}.pkl", "wb") as fout:
            pickle.dump(edger, fout)
    print(f"Lengths - Item: {len(item_edger)}, Entity: {len(entity_edger)}, Word: {len(word_edger)}")
    print(f"Edger for dataset {dataset} saved successfully.")