add edger
HiCore/edger/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .redial import redial_edger
from .tgredial import tgredial_edger
from .opendialkg import opendialkg_edger
from .durecdial import durecdial_edger
HiCore/edger/durecdial.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import json

from tqdm import tqdm


DATA_ROOT = "data/dataset/durecdial/jieba"


def get_side_data():
    # Read (head, relation, tail) triples from the DuRecDial sub-KG,
    # keeping only the head/tail entity pair of each line.
    side_data = []
    with open(f"{DATA_ROOT}/entity_subkg.txt", "r", encoding="utf-8") as file:
        for line in file.readlines():
            [a, _, b] = line.split("\t")
            b = b[:-1] if b[-1] == "\n" else b
            side_data.append((a, b))

    # Build a symmetric adjacency map over entities, then freeze the
    # neighbor sets into lists.
    entity_side = {}
    for a, b in tqdm(side_data):
        if a in entity_side:
            entity_side[a].add(b)
        else:
            entity_side[a] = set([b])
        if b in entity_side:
            entity_side[b].add(a)
        else:
            entity_side[b] = set([a])
    for a in entity_side:
        b = entity_side[a]
        entity_side[a] = list(b)

    # Build the same kind of adjacency map over words from ConceptNet,
    # restricted to tokens that occur in the dataset vocabulary.
    word_side = {}
    token_set = set([token.lower() for token in json.load(open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8"))])
    with open("data/conceptnet/zh_side.txt", "r", encoding="utf-8") as concept_net_words:
        for words in tqdm(concept_net_words.readlines()):
            a, b = words[:-1].split(" ")
            if a not in token_set or b not in token_set:
                continue
            if a in word_side:
                word_side[a].add(b)
            else:
                word_side[a] = set([b])
            if b in word_side:
                word_side[b].add(a)
            else:
                word_side[b] = set([a])
    for a in word_side:
        b = word_side[a]
        word_side[a] = list(b)

    return entity_side, word_side


def durecdial_edger():
    # Items and entities share the same graph in DuRecDial.
    item_edger, word_edger = get_side_data()
    entity_edger = item_edger
    return item_edger, entity_edger, word_edger
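For illustration (not part of the commit): a minimal sketch of the symmetric adjacency map each `get_side_data` below builds, using two made-up triples.

```python
# Hypothetical input pairs, standing in for the (head, tail) entity
# pairs parsed from entity_subkg.txt.
side_data = [("Titanic", "DiCaprio"), ("Titanic", "Winslet")]

entity_side = {}
for a, b in side_data:
    # Record the edge in both directions, exactly as the loops above do.
    entity_side.setdefault(a, set()).add(b)
    entity_side.setdefault(b, set()).add(a)
entity_side = {k: list(v) for k, v in entity_side.items()}

# entity_side["Titanic"] now lists both neighbors, and each
# neighbor lists "Titanic" back.
```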
HiCore/edger/opendialkg.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import json

from tqdm import tqdm


DATA_ROOT = "data/dataset/opendialkg/nltk"


def get_side_data():
    # Same pipeline as durecdial.py, but over the OpenDialKG sub-KG
    # and the English ConceptNet sides.
    side_data = []
    with open(f"{DATA_ROOT}/opendialkg_subkg.txt", "r", encoding="utf-8") as file:
        for line in file.readlines():
            [a, _, b] = line.split("\t")
            b = b[:-1] if b[-1] == "\n" else b
            side_data.append((a, b))

    entity_side = {}
    for a, b in tqdm(side_data):
        if a in entity_side:
            entity_side[a].add(b)
        else:
            entity_side[a] = set([b])
        if b in entity_side:
            entity_side[b].add(a)
        else:
            entity_side[b] = set([a])
    for a in entity_side:
        b = entity_side[a]
        entity_side[a] = list(b)

    word_side = {}
    token_set = set([token.lower() for token in json.load(open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8"))])
    with open("data/conceptnet/en_side.txt", "r", encoding="utf-8") as concept_net_words:
        for words in tqdm(concept_net_words.readlines()):
            a, b = words[:-1].split(" ")
            if a not in token_set or b not in token_set:
                continue
            if a in word_side:
                word_side[a].add(b)
            else:
                word_side[a] = set([b])
            if b in word_side:
                word_side[b].add(a)
            else:
                word_side[b] = set([a])
    for a in word_side:
        b = word_side[a]
        word_side[a] = list(b)

    return entity_side, word_side


def opendialkg_edger():
    item_edger, word_edger = get_side_data()
    entity_edger = item_edger
    return item_edger, entity_edger, word_edger
HiCore/edger/redial.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import json
import pickle

from tqdm import tqdm


DATA_ROOT = "data/dataset/redial/nltk"


def get_side_data():
    # ReDial's DBpedia sub-KG stores entity ids, so invert entity2id
    # to recover entity names before building the adjacency map.
    dbpedia_subkg = json.load(open(f"{DATA_ROOT}/dbpedia_subkg.json", "r", encoding="utf-8"))
    entity2id = json.load(open(f"{DATA_ROOT}/entity2id.json", "r", encoding="utf-8"))
    id2entity = {item: key for key, item in entity2id.items()}

    side_data = []
    for k in dbpedia_subkg:
        pairs = dbpedia_subkg[k]
        for pair in pairs:
            side_data.append(pair)

    entity_side = {}
    for a, b in tqdm(side_data):
        a = id2entity[a]
        b = id2entity[b]
        if a in entity_side:
            entity_side[a].add(b)
        else:
            entity_side[a] = set([b])
        if b in entity_side:
            entity_side[b].add(a)
        else:
            entity_side[b] = set([a])
    for a in entity_side:
        b = entity_side[a]
        entity_side[a] = list(b)

    word_side = {}
    token_set = set([token.lower() for token in json.load(open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8"))])
    with open("data/conceptnet/en_side.txt", "r", encoding="utf-8") as concept_net_words:
        for words in tqdm(concept_net_words.readlines()):
            a, b = words[:-1].split(" ")
            if a not in token_set or b not in token_set:
                continue
            if a in word_side:
                word_side[a].add(b)
            else:
                word_side[a] = set([b])
            if b in word_side:
                word_side[b].add(a)
            else:
                word_side[b] = set([a])
    for a in word_side:
        b = word_side[a]
        word_side[a] = list(b)

    return entity_side, word_side


def redial_edger():
    item_edger, word_edger = get_side_data()
    entity_edger = item_edger
    return item_edger, entity_edger, word_edger
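The loops above imply a particular shape for `dbpedia_subkg.json`, though the commit does not document it. A hypothetical miniature of the assumed shapes:

```python
# Assumed (not documented): each key maps to a list of [head_id, tail_id]
# pairs, and entity2id maps entity names to those integer ids.
dbpedia_subkg = {"0": [[0, 1], [0, 2]]}
entity2id = {"<Movie_A>": 0, "<Actor_B>": 1, "<Actor_C>": 2}
id2entity = {i: name for name, i in entity2id.items()}

for k in dbpedia_subkg:
    for a, b in dbpedia_subkg[k]:
        print(id2entity[a], "--", id2entity[b])
```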
HiCore/edger/tgredial.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import json
import pickle

from tqdm import tqdm


DATA_ROOT = "data/dataset/tgredial/pkuseg"


def get_side_data():
    # Same pipeline as durecdial.py, over TGReDial's CN-DBpedia triples
    # and the Chinese ConceptNet sides.
    side_data = []
    with open(f"{DATA_ROOT}/cn-dbpedia.txt", "r", encoding="utf-8") as file:
        for line in file.readlines():
            [a, _, b] = line.split("\t")
            b = b[:-1] if b[-1] == "\n" else b
            side_data.append((a, b))

    entity_side = {}
    for a, b in tqdm(side_data):
        if a in entity_side:
            entity_side[a].add(b)
        else:
            entity_side[a] = set([b])
        if b in entity_side:
            entity_side[b].add(a)
        else:
            entity_side[b] = set([a])
    for a in entity_side:
        b = entity_side[a]
        entity_side[a] = list(b)

    word_side = {}
    token_set = set([token.lower() for token in json.load(open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8"))])
    with open("data/conceptnet/zh_side.txt", "r", encoding="utf-8") as concept_net_words:
        for words in tqdm(concept_net_words.readlines()):
            a, b = words[:-1].split(" ")
            if a not in token_set or b not in token_set:
                continue
            if a in word_side:
                word_side[a].add(b)
            else:
                word_side[a] = set([b])
            if b in word_side:
                word_side[b].add(a)
            else:
                word_side[b] = set([a])
    for a in word_side:
        b = word_side[a]
        word_side[a] = list(b)

    return entity_side, word_side


def tgredial_edger():
    item_edger, word_edger = get_side_data()
    entity_edger = item_edger
    return item_edger, entity_edger, word_edger
HiCore/run_edger.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import pickle
import argparse

from edger import redial_edger, tgredial_edger, opendialkg_edger, durecdial_edger


dataset_edger_map = {
    'redial': redial_edger,
    'tgredial': tgredial_edger,
    'opendialkg': opendialkg_edger,
    'durecdial': durecdial_edger,
}


if __name__ == '__main__':
    # parse args
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset', type=str, help='Dataset name')
    args, _ = parser.parse_known_args()

    # run edger
    dataset = args.dataset
    if dataset not in dataset_edger_map:
        raise ValueError(f"Dataset {dataset} is not supported.")

    print(f"Running edger for dataset: {dataset}")
    item_edger, entity_edger, word_edger = dataset_edger_map[dataset]()

    # save edger (assumes data/edger/{dataset}/ already exists)
    pickle.dump(item_edger, open(f"data/edger/{dataset}/item_edger.pkl", "wb"))
    pickle.dump(entity_edger, open(f"data/edger/{dataset}/entity_edger.pkl", "wb"))
    pickle.dump(word_edger, open(f"data/edger/{dataset}/word_edger.pkl", "wb"))
    print(f"Lengths - Item: {len(item_edger)}, Entity: {len(entity_edger)}, Word: {len(word_edger)}")
    print(f"Edger for dataset {dataset} saved successfully.")
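A minimal usage sketch (not part of the commit), assuming `run_edger.py` has already been run for `redial`: the saved edgers can be read back with `pickle`.

```python
import pickle

# Load the word adjacency map written by run_edger.py above.
with open("data/edger/redial/word_edger.pkl", "rb") as f:
    word_edger = pickle.load(f)

# Each edger maps a node to the list of its neighbors.
print(len(word_edger), "words with ConceptNet neighbors")
```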
README.md (new file, 37 lines)
@@ -0,0 +1,37 @@
# HiCore

The implementation of "Mitigating Matthew Effect: Multi-Hypergraph Boosted Multi-Interest Self-Supervised Learning for Conversational Recommendation" (EMNLP 2024).



Our paper is available [here](https://aclanthology.org/2024.emnlp-main.86/).

## Python venv

We use `uv` to manage HiCore's Python venv. See the [uv documentation](https://docs.astral.sh/uv/) for more details about `uv`.

```bash
uv venv --python 3.12
```

## Dataset

The dataset is downloaded automatically when you run the repo's code. However, the item/entity/word edger must be built with the following command:

```bash
cd HiCore/
uv run run_edger.py -d redial
```

Alternatively, download the whole dataset and the item/entity/word edger from [here](https://drive.tokisakix.cn/share/9hQIhokW).

Place the dataset in `HiCore/data`.

## How to run

Run the CRSLab framework with the following command:

```bash
cd HiCore/
uv run run_crslab.py -c config/crs/hicore/redial.yaml -g 0 -s 3407
```
assets/hicore.png (new binary file, 761 KiB; binary file not shown)
pyproject.toml
@@ -1,5 +1,5 @@
 [project]
-name = "hicore"
+name = "HiCore"
 version = "0.1.0"
 description = "The implementation of Mitigating Matthew Effect: Multi-Hypergraph Boosted Multi-Interest Self-Supervised Learning for Conversational Recommendation (EMNLP 2024)"
 readme = "README.md"
@@ -18,6 +18,16 @@ dependencies = [
 "transformers>=4.55.0",
 ]
+
+[tool.uv.sources]
+torch = [
+    { index = "pytorch-cu129", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+
 [[tool.uv.index]]
 url="https://pypi.tuna.tsinghua.edu.cn/simple"
 default=true
+
+[[tool.uv.index]]
+name = "pytorch-cu129"
+url = "https://download.pytorch.org/whl/cu129"
+explicit = true
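With these indexes in place, a plain `uv sync` should resolve most packages from the Tsinghua mirror while pulling torch from the CUDA 12.9 index on Linux and Windows (an observation about the config above, not a step from the commit):

```bash
# Create/refresh the venv from pyproject.toml using the indexes above.
uv sync
```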