diff --git a/HiCore/edger/__init__.py b/HiCore/edger/__init__.py
new file mode 100644
index 0000000..2f53e14
--- /dev/null
+++ b/HiCore/edger/__init__.py
@@ -0,0 +1,4 @@
+from .redial import redial_edger
+from .tgredial import tgredial_edger
+from .opendialkg import opendialkg_edger
+from .durecdial import durecdial_edger
diff --git a/HiCore/edger/durecdial.py b/HiCore/edger/durecdial.py
new file mode 100644
index 0000000..17bb9d8
--- /dev/null
+++ b/HiCore/edger/durecdial.py
@@ -0,0 +1,41 @@
+import json
+from tqdm import tqdm
+
+DATA_ROOT = "data/dataset/durecdial/jieba"
+
+def get_side_data():
+    # Collect (head, tail) entity pairs from the DuRecDial sub-knowledge graph.
+    side_data = []
+    with open(f"{DATA_ROOT}/entity_subkg.txt", "r", encoding="utf-8") as file:
+        for line in file:
+            a, _, b = line.rstrip("\n").split("\t")
+            side_data.append((a, b))
+
+    # Build an undirected adjacency map over entities.
+    entity_side = {}
+    for a, b in tqdm(side_data):
+        entity_side.setdefault(a, set()).add(b)
+        entity_side.setdefault(b, set()).add(a)
+    entity_side = {k: list(v) for k, v in entity_side.items()}
+
+    # Build an undirected adjacency map over words, keeping only ConceptNet
+    # pairs whose tokens both appear in the dataset vocabulary.
+    with open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8") as f:
+        token_set = {token.lower() for token in json.load(f)}
+    word_side = {}
+    with open("data/conceptnet/zh_side.txt", "r", encoding="utf-8") as concept_net_words:
+        for words in tqdm(concept_net_words):
+            a, b = words.rstrip("\n").split(" ")
+            if a not in token_set or b not in token_set:
+                continue
+            word_side.setdefault(a, set()).add(b)
+            word_side.setdefault(b, set()).add(a)
+    word_side = {k: list(v) for k, v in word_side.items()}
+
+    return entity_side, word_side
+
+def durecdial_edger():
+    item_edger, word_edger = get_side_data()
+    # The item and entity edgers share the same entity adjacency map.
+    entity_edger = item_edger
+    return item_edger, entity_edger, word_edger
diff --git a/HiCore/edger/opendialkg.py b/HiCore/edger/opendialkg.py
new file mode 100644
index 0000000..ddfaae5
--- /dev/null
+++ b/HiCore/edger/opendialkg.py
@@ -0,0 +1,41 @@
+import json
+from tqdm import tqdm
+
+DATA_ROOT = "data/dataset/opendialkg/nltk"
+
+def get_side_data():
+    # Collect (head, tail) entity pairs from the OpenDialKG sub-knowledge graph.
+    side_data = []
+    with open(f"{DATA_ROOT}/opendialkg_subkg.txt", "r", encoding="utf-8") as file:
+        for line in file:
+            a, _, b = line.rstrip("\n").split("\t")
+            side_data.append((a, b))
+
+    # Build an undirected adjacency map over entities.
+    entity_side = {}
+    for a, b in tqdm(side_data):
+        entity_side.setdefault(a, set()).add(b)
+        entity_side.setdefault(b, set()).add(a)
+    entity_side = {k: list(v) for k, v in entity_side.items()}
+
+    # Build an undirected adjacency map over words, keeping only ConceptNet
+    # pairs whose tokens both appear in the dataset vocabulary.
+    with open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8") as f:
+        token_set = {token.lower() for token in json.load(f)}
+    word_side = {}
+    with open("data/conceptnet/en_side.txt", "r", encoding="utf-8") as concept_net_words:
+        for words in tqdm(concept_net_words):
+            a, b = words.rstrip("\n").split(" ")
+            if a not in token_set or b not in token_set:
+                continue
+            word_side.setdefault(a, set()).add(b)
+            word_side.setdefault(b, set()).add(a)
+    word_side = {k: list(v) for k, v in word_side.items()}
+
+    return entity_side, word_side
+
+def opendialkg_edger():
+    item_edger, word_edger = get_side_data()
+    # The item and entity edgers share the same entity adjacency map.
+    entity_edger = item_edger
+    return item_edger, entity_edger, word_edger
diff --git a/HiCore/edger/redial.py b/HiCore/edger/redial.py
new file mode 100644
index 0000000..acfd305
--- /dev/null
+++ b/HiCore/edger/redial.py
@@ -0,0 +1,50 @@
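+# Edger builder for the ReDial dataset: entity edges are derived from the
+# DBpedia sub-KG, word edges from ConceptNet filtered to the dataset vocabulary.
+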
+import json
+from tqdm import tqdm
+
+DATA_ROOT = "data/dataset/redial/nltk"
+
+def get_side_data():
+    # Map DBpedia entity ids back to entity names.
+    with open(f"{DATA_ROOT}/dbpedia_subkg.json", "r", encoding="utf-8") as f:
+        dbpedia_subkg = json.load(f)
+    with open(f"{DATA_ROOT}/entity2id.json", "r", encoding="utf-8") as f:
+        entity2id = json.load(f)
+    id2entity = {idx: name for name, idx in entity2id.items()}
+
+    # Flatten the sub-knowledge graph into id pairs.
+    side_data = []
+    for pairs in dbpedia_subkg.values():
+        side_data.extend(pairs)
+
+    # Build an undirected adjacency map over entity names.
+    entity_side = {}
+    for a, b in tqdm(side_data):
+        a, b = id2entity[a], id2entity[b]
+        entity_side.setdefault(a, set()).add(b)
+        entity_side.setdefault(b, set()).add(a)
+    entity_side = {k: list(v) for k, v in entity_side.items()}
+
+    # Build an undirected adjacency map over words, keeping only ConceptNet
+    # pairs whose tokens both appear in the dataset vocabulary.
+    with open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8") as f:
+        token_set = {token.lower() for token in json.load(f)}
+    word_side = {}
+    with open("data/conceptnet/en_side.txt", "r", encoding="utf-8") as concept_net_words:
+        for words in tqdm(concept_net_words):
+            a, b = words.rstrip("\n").split(" ")
+            if a not in token_set or b not in token_set:
+                continue
+            word_side.setdefault(a, set()).add(b)
+            word_side.setdefault(b, set()).add(a)
+    word_side = {k: list(v) for k, v in word_side.items()}
+
+    return entity_side, word_side
+
+def redial_edger():
+    item_edger, word_edger = get_side_data()
+    # The item and entity edgers share the same entity adjacency map.
+    entity_edger = item_edger
+    return item_edger, entity_edger, word_edger
diff --git a/HiCore/edger/tgredial.py b/HiCore/edger/tgredial.py
new file mode 100644
index 0000000..62615bc
--- /dev/null
+++ b/HiCore/edger/tgredial.py
@@ -0,0 +1,41 @@
+import json
+from tqdm import tqdm
+
+DATA_ROOT = "data/dataset/tgredial/pkuseg"
+
+def get_side_data():
+    # Collect (head, tail) entity pairs from the CN-DBpedia sub-knowledge graph.
+    side_data = []
+    with open(f"{DATA_ROOT}/cn-dbpedia.txt", "r", encoding="utf-8") as file:
+        for line in file:
+            a, _, b = line.rstrip("\n").split("\t")
+            side_data.append((a, b))
+
+    # Build an undirected adjacency map over entities.
+    entity_side = {}
+    for a, b in tqdm(side_data):
+        entity_side.setdefault(a, set()).add(b)
+        entity_side.setdefault(b, set()).add(a)
+    entity_side = {k: list(v) for k, v in entity_side.items()}
+
+    # Build an undirected adjacency map over words, keeping only ConceptNet
+    # pairs whose tokens both appear in the dataset vocabulary.
+    with open(f"{DATA_ROOT}/token2id.json", "r", encoding="utf-8") as f:
+        token_set = {token.lower() for token in json.load(f)}
+    word_side = {}
+    with open("data/conceptnet/zh_side.txt", "r", encoding="utf-8") as concept_net_words:
+        for words in tqdm(concept_net_words):
+            a, b = words.rstrip("\n").split(" ")
+            if a not in token_set or b not in token_set:
+                continue
+            word_side.setdefault(a, set()).add(b)
+            word_side.setdefault(b, set()).add(a)
+    word_side = {k: list(v) for k, v in word_side.items()}
+
+    return entity_side, word_side
+
+def tgredial_edger():
+    item_edger, word_edger = get_side_data()
+    # The item and entity edgers share the same entity adjacency map.
+    entity_edger = item_edger
+    return item_edger, entity_edger, word_edger
diff --git a/HiCore/run_edger.py b/HiCore/run_edger.py
new file mode 100644
index 0000000..4e8191b
--- /dev/null
+++ b/HiCore/run_edger.py
@@ -0,0 +1,40 @@
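+# Build the item/entity/word edgers for one dataset and pickle them under
+# data/edger/<dataset>/ for later use.
+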
+import os
+import pickle
+import argparse
+
+from edger import redial_edger, tgredial_edger, opendialkg_edger, durecdial_edger
+
+dataset_edger_map = {
+    'redial': redial_edger,
+    'tgredial': tgredial_edger,
+    'opendialkg': opendialkg_edger,
+    'durecdial': durecdial_edger
+}
+
+if __name__ == '__main__':
+    # parse args
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-d', '--dataset', type=str, required=True, help='Dataset name')
+    args, _ = parser.parse_known_args()
+
+    # run edger
+    dataset = args.dataset
+    if dataset not in dataset_edger_map:
+        raise ValueError(f"Dataset {dataset} is not supported.")
+
+    print(f"Running edger for dataset: {dataset}")
+    item_edger, entity_edger, word_edger = dataset_edger_map[dataset]()
+
+    # save edgers (create the output directory on first run)
+    os.makedirs(f"data/edger/{dataset}", exist_ok=True)
+    with open(f"data/edger/{dataset}/item_edger.pkl", "wb") as f:
+        pickle.dump(item_edger, f)
+    with open(f"data/edger/{dataset}/entity_edger.pkl", "wb") as f:
+        pickle.dump(entity_edger, f)
+    with open(f"data/edger/{dataset}/word_edger.pkl", "wb") as f:
+        pickle.dump(word_edger, f)
+    print(f"Lengths - Item: {len(item_edger)}, Entity: {len(entity_edger)}, Word: {len(word_edger)}")
+    print(f"Edger for dataset {dataset} saved successfully.")
diff --git a/README.md b/README.md
index e69de29..f5f6c70 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,37 @@
+# HiCore
+
+The implementation of "Mitigating Matthew Effect: Multi-Hypergraph Boosted Multi-Interest Self-Supervised Learning for Conversational Recommendation" (EMNLP 2024).
+
+![hicore](assets/hicore.png)
+
+Our paper is available [here](https://aclanthology.org/2024.emnlp-main.86/).
+
+## Python venv
+
+We use `uv` to manage HiCore's Python venv. See the [uv documentation](https://docs.astral.sh/uv/) for more details.
+
+```bash
+uv venv --python 3.12
+```
+
+## Dataset
+
+The datasets are downloaded automatically when you run the repo's code. However, the item/entity/word edgers must be built first with the following command:
+
+```bash
+cd HiCore/
+uv run run_edger.py -d redial
+```
+
+Alternatively, download the full datasets and prebuilt item/entity/word edgers from [here](https://drive.tokisakix.cn/share/9hQIhokW).
+
+Place the data under `HiCore/data`.
+
+## How to run
+
+Run the CRSLab framework with the following command:
+
+```bash
+cd HiCore/
+uv run run_crslab.py -c config/crs/hicore/redial.yaml -g 0 -s 3407
+```
diff --git a/assets/hicore.png b/assets/hicore.png
new file mode 100644
index 0000000..aab617a
Binary files /dev/null and b/assets/hicore.png differ
diff --git a/pyproject.toml b/pyproject.toml
index f51b477..b679bd5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
 [project]
-name = "hicore"
+name = "HiCore"
 version = "0.1.0"
 description = "The implementation of Mitigating Matthew Effect: Multi-Hypergraph Boosted Multi-Interest Self-Supervised Learning for Conversational Recommendation (EMNLP 2024)"
 readme = "README.md"
@@ -18,6 +18,16 @@ dependencies = [
     "transformers>=4.55.0",
 ]
 
+[tool.uv.sources]
+torch = [
+    { index = "pytorch-cu129", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
+]
+
 [[tool.uv.index]]
 url="https://pypi.tuna.tsinghua.edu.cn/simple"
 default=true
+
+[[tool.uv.index]]
+name = "pytorch-cu129"
+url = "https://download.pytorch.org/whl/cu129"
+explicit = true
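
Note: each `*_edger.pkl` written by `run_edger.py` is a plain pickled `dict` mapping a node to the list of its neighbors. A minimal loading sketch (the `redial` path is just an example):

```python
import pickle

# Load the entity edger produced by `uv run run_edger.py -d redial`.
with open("data/edger/redial/entity_edger.pkl", "rb") as f:
    entity_edger = pickle.load(f)  # dict: node -> list of neighbor nodes

# Peek at one adjacency list.
node, neighbors = next(iter(entity_edger.items()))
print(node, len(neighbors))
```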