add edger
This commit is contained in:
4
HyCoRec/edger/__init__.py
Normal file
4
HyCoRec/edger/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .redial import redial_edger
|
||||
from .tgredial import tgredial_edger
|
||||
from .opendialkg import opendialkg_edger
|
||||
from .durecdial import durecdial_edger
|
52
HyCoRec/edger/durecdial.py
Normal file
52
HyCoRec/edger/durecdial.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
|
||||
DATA_ROOT = "data/dataset/durecdial/jieba"
|
||||
|
||||
def get_side_data():
|
||||
side_data = []
|
||||
with open(f"{DATA_ROOT}/entity_subkg.txt", "r", encoding="utf-8") as file:
|
||||
for line in file.readlines():
|
||||
[a, _, b] = line.split("\t")
|
||||
b = b[:-1] if b[-1] == "\n" else b
|
||||
side_data.append((a, b))
|
||||
|
||||
entity_side = {}
|
||||
for a, b in tqdm(side_data):
|
||||
if a in entity_side:
|
||||
entity_side[a].add(b)
|
||||
else:
|
||||
entity_side[a] = set([b])
|
||||
if b in entity_side:
|
||||
entity_side[b].add(a)
|
||||
else:
|
||||
entity_side[b] = set([a])
|
||||
for a in entity_side:
|
||||
b = entity_side[a]
|
||||
entity_side[a] = list(b)
|
||||
|
||||
word_side = {}
|
||||
token_set = set([token.lower() for token in json.load(open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8"))])
|
||||
with(open("data/conceptnet/zh_side.txt", "r", encoding="utf-8")) as concept_net_words:
|
||||
for words in tqdm(concept_net_words.readlines()):
|
||||
a, b = words[:-1].split(" ")
|
||||
if a not in token_set or b not in token_set:
|
||||
continue
|
||||
if a in word_side:
|
||||
word_side[a].add(b)
|
||||
else:
|
||||
word_side[a] = set([b])
|
||||
if b in word_side:
|
||||
word_side[b].add(a)
|
||||
else:
|
||||
word_side[b] = set([a])
|
||||
for a in word_side:
|
||||
b = word_side[a]
|
||||
word_side[a] = list(b)
|
||||
|
||||
return entity_side, word_side
|
||||
|
||||
def durecdial_edger():
|
||||
item_edger, word_edger = get_side_data()
|
||||
entity_edger = item_edger
|
||||
return item_edger, entity_edger, word_edger
|
52
HyCoRec/edger/opendialkg.py
Normal file
52
HyCoRec/edger/opendialkg.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
|
||||
DATA_ROOT = "data/dataset/opendialkg/nltk"
|
||||
|
||||
def get_side_data():
|
||||
side_data = []
|
||||
with open(f"{DATA_ROOT}/opendialkg_subkg.txt", "r", encoding="utf-8") as file:
|
||||
for line in file.readlines():
|
||||
[a, _, b] = line.split("\t")
|
||||
b = b[:-1] if b[-1] == "\n" else b
|
||||
side_data.append((a, b))
|
||||
|
||||
entity_side = {}
|
||||
for a, b in tqdm(side_data):
|
||||
if a in entity_side:
|
||||
entity_side[a].add(b)
|
||||
else:
|
||||
entity_side[a] = set([b])
|
||||
if b in entity_side:
|
||||
entity_side[b].add(a)
|
||||
else:
|
||||
entity_side[b] = set([a])
|
||||
for a in entity_side:
|
||||
b = entity_side[a]
|
||||
entity_side[a] = list(b)
|
||||
|
||||
word_side = {}
|
||||
token_set = set([token.lower() for token in json.load(open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8"))])
|
||||
with(open("data/conceptnet/en_side.txt", "r", encoding="utf-8")) as concept_net_words:
|
||||
for words in tqdm(concept_net_words.readlines()):
|
||||
a, b = words[:-1].split(" ")
|
||||
if a not in token_set or b not in token_set:
|
||||
continue
|
||||
if a in word_side:
|
||||
word_side[a].add(b)
|
||||
else:
|
||||
word_side[a] = set([b])
|
||||
if b in word_side:
|
||||
word_side[b].add(a)
|
||||
else:
|
||||
word_side[b] = set([a])
|
||||
for a in word_side:
|
||||
b = word_side[a]
|
||||
word_side[a] = list(b)
|
||||
|
||||
return entity_side, word_side
|
||||
|
||||
def opendialkg_edger():
|
||||
item_edger, word_edger = get_side_data()
|
||||
entity_edger = item_edger
|
||||
return item_edger, entity_edger, word_edger
|
58
HyCoRec/edger/redial.py
Normal file
58
HyCoRec/edger/redial.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
DATA_ROOT = "data/dataset/redial/nltk"
|
||||
|
||||
def get_side_data():
|
||||
dbpedia_subkg = json.load(open(f"{DATA_ROOT}/dbpedia_subkg.json", "r", encoding="utf-8"))
|
||||
entity2id = json.load(open(f"{DATA_ROOT}/entity2id.json", "r", encoding="utf-8"))
|
||||
id2entity = {item:key for key, item in entity2id.items()}
|
||||
|
||||
side_data = []
|
||||
for k in dbpedia_subkg:
|
||||
pairs = dbpedia_subkg[k]
|
||||
for pair in pairs:
|
||||
side_data.append(pair)
|
||||
|
||||
entity_side = {}
|
||||
for a, b in tqdm(side_data):
|
||||
a = id2entity[a]
|
||||
b = id2entity[b]
|
||||
if a in entity_side:
|
||||
entity_side[a].add(b)
|
||||
else:
|
||||
entity_side[a] = set([b])
|
||||
if b in entity_side:
|
||||
entity_side[b].add(a)
|
||||
else:
|
||||
entity_side[b] = set([a])
|
||||
for a in entity_side:
|
||||
b = entity_side[a]
|
||||
entity_side[a] = list(b)
|
||||
|
||||
word_side = {}
|
||||
token_set = set([token.lower() for token in json.load(open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8"))])
|
||||
with(open("data/conceptnet/en_side.txt", "r", encoding="utf-8")) as concept_net_words:
|
||||
for words in tqdm(concept_net_words.readlines()):
|
||||
a, b = words[:-1].split(" ")
|
||||
if a not in token_set or b not in token_set:
|
||||
continue
|
||||
if a in word_side:
|
||||
word_side[a].add(b)
|
||||
else:
|
||||
word_side[a] = set([b])
|
||||
if b in word_side:
|
||||
word_side[b].add(a)
|
||||
else:
|
||||
word_side[b] = set([a])
|
||||
for a in word_side:
|
||||
b = word_side[a]
|
||||
word_side[a] = list(b)
|
||||
|
||||
return entity_side, word_side
|
||||
|
||||
def redial_edger():
|
||||
item_edger, word_edger = get_side_data()
|
||||
entity_edger = item_edger
|
||||
return item_edger, entity_edger, word_edger
|
53
HyCoRec/edger/tgredial.py
Normal file
53
HyCoRec/edger/tgredial.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import json
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
DATA_ROOT = "data/dataset/tgredial/pkuseg"
|
||||
|
||||
def get_side_data():
|
||||
side_data = []
|
||||
with open(f"{DATA_ROOT}/cn-dbpedia.txt", "r", encoding="utf-8") as file:
|
||||
for line in file.readlines():
|
||||
[a, _, b] = line.split("\t")
|
||||
b = b[:-1] if b[-1] == "\n" else b
|
||||
side_data.append((a, b))
|
||||
|
||||
entity_side = {}
|
||||
for a, b in tqdm(side_data):
|
||||
if a in entity_side:
|
||||
entity_side[a].add(b)
|
||||
else:
|
||||
entity_side[a] = set([b])
|
||||
if b in entity_side:
|
||||
entity_side[b].add(a)
|
||||
else:
|
||||
entity_side[b] = set([a])
|
||||
for a in entity_side:
|
||||
b = entity_side[a]
|
||||
entity_side[a] = list(b)
|
||||
|
||||
word_side = {}
|
||||
token_set = set([token.lower() for token in json.load(open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8"))])
|
||||
with(open("data/conceptnet/zh_side.txt", "r", encoding="utf-8")) as concept_net_words:
|
||||
for words in tqdm(concept_net_words.readlines()):
|
||||
a, b = words[:-1].split(" ")
|
||||
if a not in token_set or b not in token_set:
|
||||
continue
|
||||
if a in word_side:
|
||||
word_side[a].add(b)
|
||||
else:
|
||||
word_side[a] = set([b])
|
||||
if b in word_side:
|
||||
word_side[b].add(a)
|
||||
else:
|
||||
word_side[b] = set([a])
|
||||
for a in word_side:
|
||||
b = word_side[a]
|
||||
word_side[a] = list(b)
|
||||
|
||||
return entity_side, word_side
|
||||
|
||||
def tgredial_edger():
|
||||
item_edger, word_edger = get_side_data()
|
||||
entity_edger = item_edger
|
||||
return item_edger, entity_edger, word_edger
|
32
HyCoRec/run_edger.py
Normal file
32
HyCoRec/run_edger.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import pickle
|
||||
import argparse
|
||||
|
||||
from edger import redial_edger, tgredial_edger, opendialkg_edger, durecdial_edger
|
||||
|
||||
dataset_edger_map = {
|
||||
'redial': redial_edger,
|
||||
'tgredial': tgredial_edger,
|
||||
'opendialkg': opendialkg_edger,
|
||||
'durecdial': durecdial_edger
|
||||
}
|
||||
|
||||
if __name__ == '__main__':
|
||||
# parse args
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-d', '--dataset', type=str, help='Dataset name')
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
# run edger
|
||||
dataset = args.dataset
|
||||
if dataset not in dataset_edger_map:
|
||||
raise ValueError(f"Dataset {dataset} is not supported.")
|
||||
|
||||
print(f"Running edger for dataset: {dataset}")
|
||||
item_edger, entity_edger, word_edger = dataset_edger_map[dataset]()
|
||||
|
||||
# save edger
|
||||
pickle.dump(item_edger, open(f"data/edger/{dataset}/item_edger.pkl", "wb"))
|
||||
pickle.dump(entity_edger, open(f"data/edger/{dataset}/entity_edger.pkl", "wb"))
|
||||
pickle.dump(word_edger, open(f"data/edger/{dataset}/word_edger.pkl", "wb"))
|
||||
print(f"Lengths - Item: {len(item_edger)}, Entity: {len(entity_edger)}, Word: {len(word_edger)}")
|
||||
print(f"Edger for dataset {dataset} saved successfully.")
|
Reference in New Issue
Block a user