commit 9a87050deed9eab5d1ae53a8b2217f851e044f56 Author: Tokisakix <2116884726@qq.com> Date: Thu Aug 7 00:35:07 2025 +0800 init HiCore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c482d67 --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +*.pyc +*.json +*.pkl +*.txt +*.log +.built +*.npy +*.npz + +res/ +log/ + +.python-version +uv.lock \ No newline at end of file diff --git a/HiCore/config/crs/hicore/durecdial.yaml b/HiCore/config/crs/hicore/durecdial.yaml new file mode 100644 index 0000000..60c46c2 --- /dev/null +++ b/HiCore/config/crs/hicore/durecdial.yaml @@ -0,0 +1,50 @@ +# dataset +dataset: DuRecDial +tokenize: jieba +# dataloader +related_truncate: 1024 +context_truncate: 256 +response_truncate: 30 +scale: 1 +# model +model: HiCore +token_emb_dim: 300 +kg_emb_dim: 128 +num_bases: 8 +n_heads: 2 +n_layers: 2 +ffn_size: 300 +dropout: 0.1 +attention_dropout: 0.0 +relu_dropout: 0.1 +learn_positional_embeddings: false +embeddings_scale: true +reduction: false +n_positions: 1024 +user_proj_dim: 512 +# HiCore-CHANGE +mha_n_heads: 4 +pooling: Attn +extension_strategy: Adaptive +# optim +rec: + epoch: 1 + batch_size: 64 + early_stop: False + stop_mode: min + impatience: 2 + optimizer: + name: Adam + lr: !!float 1e-3 +conv: + epoch: 1 + batch_size: 128 + impatience: 1 + optimizer: + name: Adam + lr: !!float 1e-3 + lr_scheduler: + name: ReduceLROnPlateau + patience: 3 + factor: 0.5 + gradient_clip: 0.1 \ No newline at end of file diff --git a/HiCore/config/crs/hicore/opendialkg.yaml b/HiCore/config/crs/hicore/opendialkg.yaml new file mode 100644 index 0000000..7643f9c --- /dev/null +++ b/HiCore/config/crs/hicore/opendialkg.yaml @@ -0,0 +1,50 @@ +# dataset +dataset: OpenDialKG +tokenize: nltk +# dataloader +related_truncate: 1024 +context_truncate: 256 +response_truncate: 30 +scale: 1 +# model +model: HiCore +token_emb_dim: 300 +kg_emb_dim: 128 +num_bases: 8 +n_heads: 2 +n_layers: 2 +ffn_size: 300 +dropout: 0.1 +attention_dropout: 0.0 
+relu_dropout: 0.1 +learn_positional_embeddings: false +embeddings_scale: true +reduction: false +n_positions: 1024 +user_proj_dim: 512 +# HiCore-CHANGE +mha_n_heads: 4 +pooling: Mean +extension_strategy: Adaptive +# optim +rec: + epoch: 1 + batch_size: 256 + early_stop: False + stop_mode: min + impatience: 2 + optimizer: + name: Adam + lr: !!float 1e-3 +conv: + epoch: 1 + batch_size: 128 + impatience: 1 + optimizer: + name: Adam + lr: !!float 1e-3 + lr_scheduler: + name: ReduceLROnPlateau + patience: 3 + factor: 0.5 + gradient_clip: 0.1 \ No newline at end of file diff --git a/HiCore/config/crs/hicore/redial.yaml b/HiCore/config/crs/hicore/redial.yaml new file mode 100644 index 0000000..cce966d --- /dev/null +++ b/HiCore/config/crs/hicore/redial.yaml @@ -0,0 +1,50 @@ +# dataset +dataset: ReDial +tokenize: nltk +# dataloader +related_truncate: 1024 +context_truncate: 256 +response_truncate: 30 +scale: 1 +# model +model: HiCore +token_emb_dim: 300 +kg_emb_dim: 128 +num_bases: 8 +n_heads: 2 +n_layers: 2 +ffn_size: 300 +dropout: 0.1 +attention_dropout: 0.0 +relu_dropout: 0.1 +learn_positional_embeddings: false +embeddings_scale: true +reduction: false +n_positions: 1024 +user_proj_dim: 512 +# HiCore-CHANGE +mha_n_heads: 4 +pooling: Mean +extension_strategy: Adaptive +# optim +rec: + epoch: 1 + batch_size: 256 + early_stop: False + stop_mode: min + impatience: 2 + optimizer: + name: Adam + lr: !!float 1e-3 +conv: + epoch: 1 + batch_size: 128 + impatience: 1 + optimizer: + name: Adam + lr: !!float 1e-3 + lr_scheduler: + name: ReduceLROnPlateau + patience: 3 + factor: 0.5 + gradient_clip: 0.1 \ No newline at end of file diff --git a/HiCore/config/crs/hicore/tgredial.yaml b/HiCore/config/crs/hicore/tgredial.yaml new file mode 100644 index 0000000..1be7aeb --- /dev/null +++ b/HiCore/config/crs/hicore/tgredial.yaml @@ -0,0 +1,50 @@ +# dataset +dataset: TGReDial +tokenize: pkuseg +# dataloader +related_truncate: 1024 +context_truncate: 256 +response_truncate: 30 +scale: 1 +# 
model +model: HiCore +token_emb_dim: 300 +kg_emb_dim: 128 +num_bases: 8 +n_heads: 2 +n_layers: 2 +ffn_size: 300 +dropout: 0.1 +attention_dropout: 0.0 +relu_dropout: 0.1 +learn_positional_embeddings: false +embeddings_scale: true +reduction: false +n_positions: 1024 +user_proj_dim: 512 +# HiCore-CHANGE +mha_n_heads: 4 +pooling: Attn +extension_strategy: Adaptive +# optim +rec: + epoch: 1 + batch_size: 64 + early_stop: False + stop_mode: min + impatience: 2 + optimizer: + name: Adam + lr: !!float 1e-3 +conv: + epoch: 1 + batch_size: 128 + impatience: 1 + optimizer: + name: Adam + lr: !!float 1e-3 + lr_scheduler: + name: ReduceLROnPlateau + patience: 3 + factor: 0.5 + gradient_clip: 0.1 \ No newline at end of file diff --git a/HiCore/crslab/__init__.py b/HiCore/crslab/__init__.py new file mode 100644 index 0000000..b794fd4 --- /dev/null +++ b/HiCore/crslab/__init__.py @@ -0,0 +1 @@ +__version__ = '0.1.0' diff --git a/HiCore/crslab/config/__init__.py b/HiCore/crslab/config/__init__.py new file mode 100644 index 0000000..1bdfef0 --- /dev/null +++ b/HiCore/crslab/config/__init__.py @@ -0,0 +1,32 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/22 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/29 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +"""Config module which loads parameters for the whole system. + +Attributes: + SAVE_PATH (str): where system to save. + DATASET_PATH (str): where dataset to save. + MODEL_PATH (str): where model related data to save. + PRETRAIN_PATH (str): where pretrained model to save. + EMBEDDING_PATH (str): where pretrained embedding to save, used for evaluate embedding related metrics. 
class Config:
    """Configurator module that loads the defined parameters.

    Seeds all random number generators, reads a ``yaml`` config file,
    configures GPU visibility and logging, and exposes the resulting
    options through a dict-like interface.
    """

    def __init__(self, config_file, gpu='-1', debug=False, seed=2020, pretrain=False, pretrain_epoch=None):
        """Load parameters and set log level.

        Args:
            config_file (str): path to the config file, which should be in ``yaml`` format.
                You can use default config provided in the `Github repo`_, or write it by yourself.
            gpu (str, optional): comma-separated GPU ids; ``'-1'`` means CPU only. Defaults to '-1'.
            debug (bool, optional): whether to enable debug function during running. Defaults to False.
            seed (int, optional): seed for all RNG sources. Defaults to 2020.
            pretrain (bool, optional): whether this run is a pretraining run. Defaults to False.
            pretrain_epoch (int, optional): number of pretraining epochs. Defaults to None.

        .. _Github repo:
            https://github.com/RUCAIBox/CRSLab

        """
        # Seed every RNG source for reproducibility.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        self.opt = self.load_yaml_configs(config_file)
        # pretrain epoch
        self.opt['pretrain'] = pretrain
        self.opt['pretrain_epoch'] = pretrain_epoch
        # gpu: restrict visibility, then number the visible devices 0..n-1
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu
        self.opt['gpu'] = [i for i in range(len(gpu.split(',')))] if gpu != '-1' else [-1]
        # dataset
        dataset = self.opt['dataset']
        tokenize = self.opt['tokenize']
        if isinstance(tokenize, dict):
            tokenize = ', '.join(tokenize.values())
        # model: either a single end-to-end model, or separate rec/conv/policy models
        model = self.opt.get('model', None)
        rec_model = self.opt.get('rec_model', None)
        conv_model = self.opt.get('conv_model', None)
        policy_model = self.opt.get('policy_model', None)
        if model:
            model_name = model
        else:
            models = [name for name in (rec_model, conv_model, policy_model) if name]
            model_name = '_'.join(models)
        self.opt['model_name'] = model_name
        self.opt['rankfile'] = f"{model_name}-{dataset}.json"
        # log: one file per run, named after dataset/model/timestamp unless overridden
        log_name = self.opt.get("log_name", dataset + '_' + model_name + '_' + time.strftime("%Y-%m-%d-%H-%M-%S",
                                                                                             time.localtime())) + ".log"
        if not os.path.exists("log"):
            os.makedirs("log")
        logger.remove()
        level = 'DEBUG' if debug else 'INFO'
        logger.add(os.path.join("log", log_name), level=level)
        # route console output through tqdm so progress bars stay intact
        logger.add(lambda msg: tqdm.write(msg, end=''), colorize=True, level=level)

        logger.info(f"[Dataset: {dataset} tokenized in {tokenize}]")
        if model:
            logger.info(f'[Model: {model}]')
        if rec_model:
            logger.info(f'[Recommendation Model: {rec_model}]')
        if conv_model:
            logger.info(f'[Conversation Model: {conv_model}]')
        if policy_model:
            logger.info(f'[Policy Model: {policy_model}]')
        # BUG FIX: '/n' was a typo for the newline escape '\n'.
        logger.info("[Config]" + '\n' + json.dumps(self.opt, indent=4))

    @staticmethod
    def load_yaml_configs(filename):
        """This function reads ``yaml`` file to build config dictionary

        Args:
            filename (str): path to ``yaml`` config

        Returns:
            dict: config

        """
        config_dict = dict()
        with open(filename, 'r', encoding='utf-8') as f:
            config_dict.update(yaml.safe_load(f.read()))
        return config_dict

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        self.opt[key] = value

    def __getitem__(self, item):
        # Missing keys yield None rather than raising KeyError.
        if item in self.opt:
            return self.opt[item]
        else:
            return None

    def get(self, item, default=None):
        """Get value of corrsponding item in config

        Args:
            item (str): key to query in config
            default (optional): default value for item if not found in config. Defaults to None.

        Returns:
            value of corrsponding item in config

        """
        if item in self.opt:
            return self.opt[item]
        else:
            return default

    def __contains__(self, key):
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        return key in self.opt

    def __str__(self):
        return str(self.opt)

    def __repr__(self):
        return self.__str__()


if __name__ == '__main__':
    # NOTE(review): 'hredial.yaml' is not among the configs added in this
    # commit (durecdial/opendialkg/redial/tgredial) -- verify the path.
    opt_dict = Config('config/crs/hicore/hredial.yaml')
    pprint(opt_dict)
def get_dataset(opt, tokenize, restore, save) -> BaseDataset:
    """get and process dataset

    Args:
        opt (Config or dict): config for dataset or the whole system.
        tokenize (str): how to tokenize the dataset.
        restore (bool): whether to restore saved dataset which has been processed.
        save (bool): whether to save dataset after processing.

    Returns:
        processed dataset

    Raises:
        NotImplementedError: if ``opt['dataset']`` is not registered.

    """
    dataset = opt['dataset']
    if dataset in dataset_register_table:
        return dataset_register_table[dataset](opt, tokenize, restore, save)
    else:
        # BUG FIX: the message previously said "dataloader" for a missing dataset.
        raise NotImplementedError(f'The dataset [{dataset}] has not been implemented')


def get_dataloader(opt, dataset, vocab) -> BaseDataLoader:
    """get dataloader to batchify dataset

    Args:
        opt (Config or dict): config for dataloader or the whole system.
        dataset: processed raw data, no side data.
        vocab (dict): all kinds of useful size, idx and map between token and idx.

    Returns:
        dataloader

    Raises:
        NotImplementedError: if ``opt['model_name']`` is not registered.

    """
    model_name = opt['model_name']
    if model_name in dataloader_register_table:
        return dataloader_register_table[model_name](opt, dataset, vocab)
    else:
        raise NotImplementedError(f'The dataloader [{model_name}] has not been implemented')
class BaseDataLoader(ABC):
    """Abstract class of dataloader

    Notes:
        ``'scale'`` can be set in config to limit the size of dataset.

    """

    def __init__(self, opt, dataset):
        """
        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: dataset

        """
        self.opt = opt
        self.dataset = dataset
        # fraction of the dataset actually used; must lie in (0, 1]
        self.scale = opt.get('scale', 1)
        assert 0 < self.scale <= 1

    def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None):
        """Collate batch data for system to fit

        Args:
            batch_fn (func): function to collate data.
            batch_size (int): size of one batch.
            shuffle (bool, optional): Defaults to True.
            process_fn (func, optional): function to process dataset before batchify. Defaults to None.

        Yields:
            tuple or dict of torch.Tensor: batch data for system to fit

        """
        dataset = self.dataset
        if process_fn is not None:
            dataset = process_fn()
        # keep only the first `scale` fraction of examples
        dataset = dataset[:ceil(len(dataset) * self.scale)]
        logger.debug(f'[Dataset size: {len(dataset)}]')

        batch_num = ceil(len(dataset) / batch_size)
        idx_list = list(range(len(dataset)))
        if shuffle:
            random.shuffle(idx_list)

        for start_idx in tqdm(range(batch_num)):
            batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size]
            batch = batch_fn([dataset[idx] for idx in batch_idx])
            # BUG FIX: use an identity check -- `batch == False` also matched
            # other values equal to False (e.g. 0), while batchify functions
            # signal "skip this batch" with the singleton False.
            if batch is False:
                continue
            yield batch

    def get_conv_data(self, batch_size, shuffle=True):
        """get_data wrapper for conversation.

        You can implement your own process_fn in ``conv_process_fn``, batch_fn in ``conv_batchify``.

        Args:
            batch_size (int):
            shuffle (bool, optional): Defaults to True.

        Yields:
            tuple or dict of torch.Tensor: batch data for conversation.

        """
        return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn)

    def get_rec_data(self, batch_size, shuffle=True):
        """get_data wrapper for recommendation.

        You can implement your own process_fn in ``rec_process_fn``, batch_fn in ``rec_batchify``.

        Args:
            batch_size (int):
            shuffle (bool, optional): Defaults to True.

        Yields:
            tuple or dict of torch.Tensor: batch data for recommendation.

        """
        return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn)

    def get_policy_data(self, batch_size, shuffle=True):
        """get_data wrapper for policy.

        You can implement your own process_fn in ``self.policy_process_fn``, batch_fn in ``policy_batchify``.

        Args:
            batch_size (int):
            shuffle (bool, optional): Defaults to True.

        Yields:
            tuple or dict of torch.Tensor: batch data for policy.

        """
        return self.get_data(self.policy_batchify, batch_size, shuffle, self.policy_process_fn)

    def conv_process_fn(self):
        """Process whole data for conversation before batch_fn.

        Returns:
            processed dataset. Defaults to return the same as `self.dataset`.

        """
        return self.dataset

    def conv_batchify(self, batch):
        """batchify data for conversation after process.

        Args:
            batch (list): processed batch dataset.

        Returns:
            batch data for the system to train conversation part.
        """
        raise NotImplementedError('dataloader must implement conv_batchify() method')

    def rec_process_fn(self):
        """Process whole data for recommendation before batch_fn.

        Returns:
            processed dataset. Defaults to return the same as `self.dataset`.

        """
        return self.dataset

    def rec_batchify(self, batch):
        """batchify data for recommendation after process.

        Args:
            batch (list): processed batch dataset.

        Returns:
            batch data for the system to train recommendation part.
        """
        raise NotImplementedError('dataloader must implement rec_batchify() method')

    def policy_process_fn(self):
        """Process whole data for policy before batch_fn.

        Returns:
            processed dataset. Defaults to return the same as `self.dataset`.

        """
        return self.dataset

    def policy_batchify(self, batch):
        """batchify data for policy after process.

        Args:
            batch (list): processed batch dataset.

        Returns:
            batch data for the system to train policy part.
        """
        raise NotImplementedError('dataloader must implement policy_batchify() method')

    def retain_recommender_target(self):
        """keep data whose role is recommender.

        Returns:
            Recommender part of ``self.dataset``.

        """
        dataset = []
        for conv_dict in tqdm(self.dataset):
            if conv_dict['role'] == 'Recommender':
                dataset.append(conv_dict)
        return dataset

    def rec_interact(self, data):
        """process user input data for system to recommend.

        Args:
            data: user input data.

        Returns:
            data for system to recommend.
        """
        pass

    def conv_interact(self, data):
        """Process user input data for system to converse.

        Args:
            data: user input data.

        Returns:
            data for system in converse.
        """
        pass
+ + """ + super().__init__(opt, dataset) + self.pad_token_idx = vocab["tok2ind"]["__pad__"] + self.start_token_idx = vocab["tok2ind"]["__start__"] + self.end_token_idx = vocab["tok2ind"]["__end__"] + self.split_token_idx = vocab["tok2ind"].get("_split_", None) + self.related_truncate = opt.get("related_truncate", None) + self.context_truncate = opt.get("context_truncate", None) + self.response_truncate = opt.get("response_truncate", None) + self.entity_truncate = opt.get("entity_truncate", None) + self.review_entity2id = vocab["entity2id"] + return + + def rec_process_fn(self): + augment_dataset = [] + for conv_dict in tqdm(self.dataset): + if conv_dict["role"] == "Recommender": + for item in conv_dict["items"]: + augment_conv_dict = { + "conv_id": conv_dict["conv_id"], + "related_item": conv_dict["item"], + "related_entity": conv_dict["entity"], + "related_word": conv_dict["word"], + "item": item, + } + augment_dataset.append(augment_conv_dict) + + return augment_dataset + + def rec_batchify(self, batch): + batch_related_item = [] + batch_related_entity = [] + batch_related_word = [] + batch_movies = [] + batch_conv_id = [] + for conv_dict in batch: + batch_related_item.append(conv_dict["related_item"]) + batch_related_entity.append(conv_dict["related_entity"]) + batch_related_word.append(conv_dict["related_word"]) + batch_movies.append(conv_dict["item"]) + batch_conv_id.append(conv_dict["conv_id"]) + + res = { + "conv_id": batch_conv_id, + "related_item": batch_related_item, + "related_entity": batch_related_entity, + "related_word": batch_related_word, + "item": torch.tensor(batch_movies, dtype=torch.long), + } + + return res + + def conv_process_fn(self, *args, **kwargs): + return self.retain_recommender_target() + + def conv_batchify(self, batch): + batch_related_tokens = [] + batch_context_tokens = [] + + batch_related_item = [] + batch_related_entity = [] + batch_related_word = [] + + batch_response = [] + batch_conv_id = [] + for conv_dict in batch: + 
batch_related_tokens.append( + truncate(conv_dict["tokens"][-1], self.related_truncate, truncate_tail=False) + ) + batch_context_tokens.append( + truncate(merge_utt( + conv_dict["tokens"], + start_token_idx=self.start_token_idx, + split_token_idx=self.split_token_idx, + final_token_idx=self.end_token_idx + ), self.context_truncate, truncate_tail=False) + ) + + batch_related_item.append(conv_dict["item"]) + batch_related_entity.append(conv_dict["entity"]) + batch_related_word.append(conv_dict["word"]) + + batch_response.append( + add_start_end_token_idx(truncate(conv_dict["response"], self.response_truncate - 2), + start_token_idx=self.start_token_idx, + end_token_idx=self.end_token_idx)) + batch_conv_id.append(conv_dict["conv_id"]) + + res = { + "related_tokens": padded_tensor(batch_related_tokens, self.pad_token_idx, pad_tail=False), + "context_tokens": padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False), + "related_item": batch_related_item, + "related_entity": batch_related_entity, + "related_word": batch_related_word, + "response": padded_tensor(batch_response, self.pad_token_idx), + "conv_id": batch_conv_id, + } + + return res + + def policy_batchify(self, *args, **kwargs): + pass diff --git a/HiCore/crslab/data/dataloader/utils.py b/HiCore/crslab/data/dataloader/utils.py new file mode 100644 index 0000000..157d62a --- /dev/null +++ b/HiCore/crslab/data/dataloader/utils.py @@ -0,0 +1,182 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/10 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/20, 2020/12/15 +# @Author : Xiaolei Wang, Yuanhang Zhou +# @email : wxl1999@foxmail.com, sdzyh002@gmail + +# UPDATE +# @Time : 2021/10/06 +# @Author : Zhipeng Zhao +# @Email : oran_official@outlook.com + + +from copy import copy + +import torch +from typing import List, Union, Optional + + +def padded_tensor( + items: List[Union[List[int], torch.LongTensor]], + pad_idx: int = 0, + pad_tail: bool = True, + max_len: 
def padded_tensor(
        items: List[Union[List[int], torch.LongTensor]],
        pad_idx: int = 0,
        pad_tail: bool = True,
        max_len: Optional[int] = None,
) -> torch.LongTensor:
    """Pad an uneven list of sequences into a single 2-D LongTensor.

    Sequences are left-aligned (padding at the tail) by default; set
    ``pad_tail=False`` to right-align them instead. If the first item is
    already a tensor, its type/device is reused for the output.

    :param list[iter[int]] items: List of items
    :param int pad_idx: the value to use for padding
    :param bool pad_tail:
    :param int max_len: if None, the max length is the maximum item length

    :returns: padded tensor.
    :rtype: Tensor[int64]

    """
    rows = len(items)
    lengths: List[int] = [len(seq) for seq in items]  # type: ignore
    width = max(lengths) if max_len is None else max_len
    # guarantee at least one column even if every sequence is empty
    width = max(width, 1)

    first = items[0]
    if isinstance(first, torch.Tensor):
        # inherit dtype/device from the input tensors (may be cuda)
        out = first.new(rows, width)  # type: ignore
    else:
        out = torch.LongTensor(rows, width)  # type: ignore
    out.fill_(pad_idx)

    for row, (seq, seq_len) in enumerate(zip(items, lengths)):
        if seq_len == 0:
            continue  # nothing to copy for empty sequences
        if not isinstance(seq, torch.Tensor):
            seq = torch.tensor(seq, dtype=torch.long)  # type: ignore
        if pad_tail:
            out[row, :seq_len] = seq
        else:
            out[row, width - seq_len:] = seq

    return out
+ + """ + onehot_labels = [] + for label_list in data_list: + onehot_label = torch.zeros(categories) + for label in label_list: + onehot_label[label] = 1.0 / len(label_list) + onehot_labels.append(onehot_label) + return torch.stack(onehot_labels, dim=0) + + +def add_start_end_token_idx(vec: list, start_token_idx: int = None, end_token_idx: int = None): + """Can choose to add start token in the beginning and end token in the end. + + Args: + vec: source list composed of indexes. + start_token_idx: index of start token. + end_token_idx: index of end token. + + Returns: + list: list added start or end token index. + + """ + res = copy(vec) + if start_token_idx: + res.insert(0, start_token_idx) + if end_token_idx: + res.append(end_token_idx) + return res + + +def truncate(vec, max_length, truncate_tail=True): + """truncate vec to make its length no more than max length. + + Args: + vec (list): source list. + max_length (int) + truncate_tail (bool, optional): Defaults to True. + + Returns: + list: truncated vec. + + """ + if max_length is None: + return vec + if len(vec) <= max_length: + return vec + if max_length == 0: + return [] + if truncate_tail: + return vec[:max_length] + else: + return vec[-max_length:] + + +def merge_utt(conversation, start_token_idx=None, split_token_idx=None, keep_split_in_tail=False, final_token_idx=None): + """merge utterances in one conversation. + + Args: + conversation (list of list of int): conversation consist of utterances consist of tokens. + split_token_idx (int): index of split token. Defaults to None. + keep_split_in_tail (bool): split in tail or head. Defaults to False. + final_token_idx (int): index of final token. Defaults to None. + + Returns: + list: tokens of all utterances in one list. 
+ + """ + merged_conv = [] + if start_token_idx: + merged_conv.append(start_token_idx) + for utt in conversation: + for token in utt: + merged_conv.append(token) + if split_token_idx: + merged_conv.append(split_token_idx) + if split_token_idx and not keep_split_in_tail: + merged_conv = merged_conv[:-1] + if final_token_idx: + merged_conv.append(final_token_idx) + return merged_conv + +def merge_utt_replace(conversation,detect_token=None,replace_token=None,method="in"): + if method == 'in': + replaced_conv = [] + for utt in conversation: + for token in utt: + if detect_token in token: + replaced_conv.append(replace_token) + else: + replaced_conv.append(token) + return replaced_conv + else: + return [token.replace(detect_token,replace_token) for utt in conversation for token in utt] diff --git a/HiCore/crslab/data/dataset/__init__.py b/HiCore/crslab/data/dataset/__init__.py new file mode 100644 index 0000000..8e78702 --- /dev/null +++ b/HiCore/crslab/data/dataset/__init__.py @@ -0,0 +1,8 @@ +from .base import BaseDataset +from .hredial import HReDialDataset +from .htgredial import HTGReDialDataset +from .opendialkg import OpenDialKGDataset +from .durecdial import DuRecDialDataset + +from .redial import ReDialDataset +from .tgredial import TGReDialDataset \ No newline at end of file diff --git a/HiCore/crslab/data/dataset/base.py b/HiCore/crslab/data/dataset/base.py new file mode 100644 index 0000000..cbd962b --- /dev/null +++ b/HiCore/crslab/data/dataset/base.py @@ -0,0 +1,171 @@ +# @Time : 2020/11/22 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/11/23, 2020/12/13 +# @Author : Kun Zhou, Xiaolei Wang +# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com + +import os +import pickle as pkl +from abc import ABC, abstractmethod + +import numpy as np +from loguru import logger + +from crslab.download import build + + +class BaseDataset(ABC): + """Abstract class of dataset + + Notes: + ``'embedding'`` can be specified in config 
class BaseDataset(ABC):
    """Abstract class of dataset

    Notes:
        ``'embedding'`` can be specified in config to use pretrained word embedding.

    """

    def __init__(self, opt, dpath, resource, restore=False, save=False):
        """Download resource, load, process data. Support restore and save processed dataset.

        Args:
            opt (Config or dict): config for dataset or the whole system.
            dpath (str): where to store dataset.
            resource (dict): version, download file and special token idx of tokenized dataset.
            restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
            save (bool): whether to save dataset after processing. Defaults to False.

        """
        self.opt = opt
        self.dpath = dpath

        # fetch the raw resource files (no-op when already present)
        build(dpath, resource['file'], version=resource['version'])

        if restore:
            # reuse a previously processed dataset instead of reprocessing
            self._load_other_data()
            (self.train_data, self.valid_data, self.test_data,
             self.side_data, self.vocab) = self._load_from_restore()
        else:
            raw_train, raw_valid, raw_test, self.vocab = self._load_data()
            logger.info('[Finish data load]')
            (self.train_data, self.valid_data, self.test_data,
             self.side_data) = self._data_preprocess(raw_train, raw_valid, raw_test)
            embedding = opt.get('embedding', None)
            if embedding:
                # attach pretrained word embedding for embedding-based metrics
                self.side_data["embedding"] = np.load(os.path.join(self.dpath, embedding))
                logger.debug(f'[Load pretrained embedding {embedding}]')
            logger.info('[Finish data preprocess]')

        if save:
            bundle = (self.train_data, self.valid_data, self.test_data, self.side_data, self.vocab)
            self._save_to_one(bundle)

    @abstractmethod
    def _load_other_data(self):
        """Load auxiliary (review) data; invoked only on restore."""
        pass

    @abstractmethod
    def _load_data(self):
        """Load dataset.

        Returns:
            (any, any, any, dict):

            raw train, valid and test data.

            vocab: all kinds of useful size, idx and map between token and idx.

        """
        pass

    @abstractmethod
    def _data_preprocess(self, train_data, valid_data, test_data):
        """Process raw train, valid, test data into model-ready form.

        Args:
            train_data: train dataset.
            valid_data: valid dataset.
            test_data: test dataset.

        Returns:
            (list of dict, dict): processed train/valid/test data (per-turn
            dicts with role, context tokens/entities/words, response, items,
            policy targets, etc.) and side data (entity kg, word kg,
            item entity ids). See concrete subclasses for the exact schema.

        """
        pass

    def _load_from_restore(self, file_name="all_data.pkl"):
        """Restore saved dataset.

        Args:
            file_name (str): file of saved dataset. Defaults to "all_data.pkl".

        Raises:
            ValueError: if the saved file does not exist.

        """
        target = os.path.join(self.dpath, file_name)
        if not os.path.exists(target):
            raise ValueError(f'Saved dataset [{file_name}] does not exist')
        with open(target, 'rb') as f:
            dataset = pkl.load(f)
        logger.info(f'Restore dataset from [{file_name}]')
        return dataset

    def _save_to_one(self, data, file_name="all_data.pkl"):
        """Save all processed dataset and vocab into one file.

        Args:
            data (tuple): all dataset and vocab.
            file_name (str, optional): file to save dataset. Defaults to "all_data.pkl".

        """
        if not os.path.exists(self.dpath):
            os.makedirs(self.dpath)
        with open(os.path.join(self.dpath, file_name), 'wb') as f:
            pkl.dump(data, f)
        logger.info(f'[Save dataset to {file_name}]')
_"Towards Conversational Recommendation over Multi-Type Dialogs.": + https://www.aclweb.org/anthology/2020.acl-main.98/ + +""" + +import json +import os +from copy import copy + +from loguru import logger +from tqdm import tqdm + +from crslab.config import DATASET_PATH +from crslab.data.dataset.base import BaseDataset +from .resources import resources + + +class DuRecDialDataset(BaseDataset): + """ + + Attributes: + train_data: train dataset. + valid_data: valid dataset. + test_data: test dataset. + vocab (dict): :: + + { + 'tok2ind': map from token to index, + 'ind2tok': map from index to token, + 'entity2id': map from entity to index, + 'id2entity': map from index to entity, + 'word2id': map from word to index, + 'vocab_size': len(self.tok2ind), + 'n_entity': max(self.entity2id.values()) + 1, + 'n_word': max(self.word2id.values()) + 1, + } + + Notes: + ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``. + + """ + + def __init__(self, opt, tokenize, restore=False, save=False): + """ + + Args: + opt (Config or dict): config for dataset or the whole system. + tokenize (str): how to tokenize dataset. + restore (bool): whether to restore saved dataset which has been processed. Defaults to False. + save (bool): whether to save dataset after processing. Defaults to False. 
+ + """ + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] + self.unk_token_idx = self.special_token_idx['unk'] + dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize) + self.train_review = None + self.valid_review = None + self.test_review = None + super().__init__(opt, dpath, resource, restore, save) + + def _load_data(self): + train_data, valid_data, test_data = self._load_raw_data() + self._load_vocab() + self._load_other_data() + + vocab = { + 'tok2ind': self.tok2ind, + 'ind2tok': self.ind2tok, + 'entity2id': self.entity2id, + 'id2entity': self.id2entity, + 'word2id': self.word2id, + 'vocab_size': len(self.tok2ind), + 'n_entity': self.n_entity, + 'n_word': self.n_word, + } + vocab.update(self.special_token_idx) + + return train_data, valid_data, test_data, vocab + + def _load_raw_data(self): + with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: + train_data = json.load(f) + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: + valid_data = json.load(f) + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: + test_data = json.load(f) + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + + return train_data, valid_data, test_data + + def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} + + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + + def _load_other_data(self): + # entity kg + with 
open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f: + self.entity2id = json.load(f) # {entity: entity_id} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.n_entity = max(self.entity2id.values()) + 1 + # {head_entity_id: [(relation_id, tail_entity_id)]} + self.entity_kg = open(os.path.join(self.dpath, 'entity_subkg.txt'), encoding='utf-8') + logger.debug( + f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]") + + # hownet + # {concept: concept_id} + with open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8') as f: + self.word2id = json.load(f) + self.n_word = max(self.word2id.values()) + 1 + # {concept \t relation\t concept} + self.word_kg = open(os.path.join(self.dpath, 'hownet_subkg.txt'), encoding='utf-8') + logger.debug( + f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet_subkg.txt')}]") + + def _data_preprocess(self, train_data, valid_data, test_data): + processed_train_data = self._raw_data_process(train_data) + logger.debug("[Finish train data process]") + processed_valid_data = self._raw_data_process(valid_data) + logger.debug("[Finish valid data process]") + processed_test_data = self._raw_data_process(test_data) + logger.debug("[Finish test data process]") + processed_side_data = self._side_data_process() + logger.debug("[Finish side data process]") + return processed_train_data, processed_valid_data, processed_test_data, processed_side_data + + def _raw_data_process(self, raw_data): + augmented_convs = [self._convert_to_id(idx, conversation) for idx, conversation in enumerate(tqdm(raw_data))] + augmented_conv_dicts = [] + for conv in tqdm(augmented_convs): + augmented_conv_dicts.extend(self._augment_and_add(conv)) + return augmented_conv_dicts + + def _convert_to_id(self, idx, conversation): + augmented_convs = [] + last_role = None + 
conv_id = conversation.get("conv_id", idx) + related_item = [] + related_entity = [] + related_word = [] + for utt in conversation['dialog']: + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + + related_item += item_ids + related_entity += entity_ids + related_word += word_ids + + if utt["role"] == last_role: + augmented_convs[-1]["text"] += text_token_ids + augmented_convs[-1]["item"] += item_ids + augmented_convs[-1]["entity"] += entity_ids + augmented_convs[-1]["word"] += word_ids + else: + augmented_convs.append({ + "conv_id": conv_id, + "role": utt["role"], + "text": text_token_ids, + "item": related_item, + "entity": related_entity, + "word": related_word + }) + last_role = utt["role"] + + return augmented_convs + + def _augment_and_add(self, raw_conv_dict): + augmented_conv_dicts = [] + context_tokens, context_entities, context_words, context_items = [], [], [], [] + entity_set, word_set = set(), set() + for i, conv in enumerate(raw_conv_dict): + text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"] + if len(context_tokens) > 0: + conv_dict = { + "conv_id": conv["conv_id"], + "role": conv["role"], + "tokens": copy(context_tokens), + "response": text_tokens, + "item": copy(context_items), + "entity": copy(context_entities), + "word": copy(context_words), + "items": movies + } + augmented_conv_dicts.append(conv_dict) + + context_tokens.append(text_tokens) + context_items += movies + for entity in entities + movies: + if entity not in entity_set: + entity_set.add(entity) + context_entities.append(entity) + for word in words: + if word not in word_set: + word_set.add(word) + context_words.append(word) + + return 
augmented_conv_dicts + + def _side_data_process(self): + processed_entity_kg = self._entity_kg_process() + logger.debug("[Finish entity KG process]") + processed_word_kg = self._word_kg_process() + logger.debug("[Finish word KG process]") + with open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8') as f: + item_entity_ids = json.load(f) + logger.debug('[Load movie entity ids]') + + side_data = { + "entity_kg": processed_entity_kg, + "word_kg": processed_word_kg, + "item_entity_ids": item_entity_ids, + } + return side_data + + def _entity_kg_process(self): + edge_list = [] # [(entity, entity, relation)] + for line in self.entity_kg: + triple = line.strip().split('\t') + e0 = self.entity2id[triple[0]] + e1 = self.entity2id[triple[2]] + r = triple[1] + edge_list.append((e0, e1, r)) + edge_list.append((e1, e0, r)) + edge_list.append((e0, e0, 'SELF_LOOP')) + if e1 != e0: + edge_list.append((e1, e1, 'SELF_LOOP')) + + relation2id, edges, entities = dict(), set(), set() + for h, t, r in edge_list: + if r not in relation2id: + relation2id[r] = len(relation2id) + edges.add((h, t, relation2id[r])) + entities.add(self.id2entity[h]) + entities.add(self.id2entity[t]) + + return { + 'edge': list(edges), + 'n_relation': len(relation2id), + 'entity': list(entities) + } + + def _word_kg_process(self): + edges = set() # {(entity, entity)} + entities = set() + for line in self.word_kg: + triple = line.strip().split('\t') + entities.add(triple[0]) + entities.add(triple[2]) + e0 = self.word2id[triple[0]] + e1 = self.word2id[triple[2]] + edges.add((e0, e1)) + edges.add((e1, e0)) + # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]] + return { + 'edge': list(edges), + 'entity': list(entities) + } diff --git a/HiCore/crslab/data/dataset/durecdial/resources.py b/HiCore/crslab/data/dataset/durecdial/resources.py new file mode 100644 index 0000000..6bb858f --- /dev/null +++ b/HiCore/crslab/data/dataset/durecdial/resources.py @@ -0,0 +1,70 @@ +# -*- 
encoding: utf-8 -*- +# @Time : 2020/12/22 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/22 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +resources = { + 'jieba': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1', + 'durecdial_jieba.zip', + 'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'bert': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1', + 'durecdial_bert.zip', + '0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, + }, + 'gpt2': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1', + 'durecdial_gpt2.zip', + 'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } +} diff --git a/HiCore/crslab/data/dataset/hredial/__init__.py b/HiCore/crslab/data/dataset/hredial/__init__.py new file mode 100644 index 0000000..f2130bb --- /dev/null +++ b/HiCore/crslab/data/dataset/hredial/__init__.py @@ -0,0 +1 @@ +from .hredial import HReDialDataset \ No newline at end of file 
diff --git a/HiCore/crslab/data/dataset/hredial/hredial.py b/HiCore/crslab/data/dataset/hredial/hredial.py new file mode 100644 index 0000000..5c0f0d7 --- /dev/null +++ b/HiCore/crslab/data/dataset/hredial/hredial.py @@ -0,0 +1,209 @@ +# @Time : 2020/11/22 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/11/23, 2021/1/3, 2020/12/19 +# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou +# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail + +r""" +ReDial +====== +References: + Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS 2018. + +.. _`"Towards deep conversational recommendations."`: + https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html + +""" + +import json +import os +import pickle as pkl +from copy import copy + +from loguru import logger +import torch +from tqdm import tqdm + +from crslab.config import DATASET_PATH +from crslab.data.dataset.base import BaseDataset +from .resources import resources + + +class HReDialDataset(BaseDataset): + """ + + Attributes: + train_data: train dataset. + valid_data: valid dataset. + test_data: test dataset. + vocab (dict): :: + + { + 'tok2ind': map from token to index, + 'ind2tok': map from index to token, + 'entity2id': map from entity to index, + 'id2entity': map from index to entity, + 'word2id': map from word to index, + 'vocab_size': len(self.tok2ind), + 'n_entity': max(self.entity2id.values()) + 1, + 'n_word': max(self.word2id.values()) + 1, + } + + Notes: + ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``. + + """ + + def __init__(self, opt, tokenize, restore=False, save=False): + """Specify tokenized resource and init base dataset. + + Args: + opt (Config or dict): config for dataset or the whole system. + tokenize (str): how to tokenize dataset. + restore (bool): whether to restore saved dataset which has been processed. Defaults to False. 
+ save (bool): whether to save dataset after processing. Defaults to False. + + """ + resource = resources[tokenize] + dpath = os.path.join(DATASET_PATH, "hredial", tokenize) + super().__init__(opt, dpath, resource, restore, save) + + def _load_data(self): + train_data, valid_data, test_data = self._load_raw_data() + self._load_vocab() + self._load_other_data() + + vocab = { + 'tok2ind': self.tok2ind, + 'ind2tok': self.ind2tok, + 'entity2id': self.entity2id, + 'id2entity': self.id2entity, + 'vocab_size': len(self.tok2ind), + 'n_entity': self.n_entity + } + + return train_data, valid_data, test_data, vocab + + def _load_raw_data(self): + # load train/valid/test data + with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: + train_data = json.load(f) + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: + valid_data = json.load(f) + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: + test_data = json.load(f) + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + + return train_data, valid_data, test_data + + def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} + + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + + def _load_other_data(self): + # edge extension data + self.conv2items = json.load(open(os.path.join(self.dpath, 'conv2items.json'), 'r', encoding='utf-8')) + # dbpedia + self.entity2id = json.load( + open(os.path.join(self.dpath, 
'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.n_entity = max(self.entity2id.values()) + 1 + self.side_data = pkl.load(open(os.path.join(self.dpath, 'side_data.pkl'), 'rb')) + logger.debug( + f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]") + + def _data_preprocess(self, train_data, valid_data, test_data): + processed_train_data = self._raw_data_process(train_data) + logger.debug("[Finish train data process]") + processed_valid_data = self._raw_data_process(valid_data) + logger.debug("[Finish valid data process]") + processed_test_data = self._raw_data_process(test_data) + logger.debug("[Finish test data process]") + processed_side_data = self.side_data + logger.debug("[Finish side data process]") + return processed_train_data, processed_valid_data, processed_test_data, processed_side_data + + def _raw_data_process(self, raw_data): + augmented_convs = [self._convert_to_id(conv) for convs in tqdm(raw_data) for conv in convs] + augmented_conv_dicts = [] + for conv in tqdm(augmented_convs): + augmented_conv_dicts.extend(self._augment_and_add(conv)) + return augmented_conv_dicts + + # 将文本、电影、实体信息转换为序号 + def _convert_to_id(self, conversation): + augmented_convs = [] + last_role = None + conv_id = conversation["conv_id"] + related_item = [] + related_entity = [] + related_word = [] + for utt in conversation['dialog']: + self.unk_token_idx = 3 + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + item_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.tok2ind[word] for word in utt['text'] if word in self.tok2ind] + + related_item += item_ids + related_entity += entity_ids + related_word += word_ids + + if 
utt["role"] == last_role: + augmented_convs[-1]["text"] += text_token_ids + augmented_convs[-1]["item"] += item_ids + augmented_convs[-1]["entity"] += entity_ids + augmented_convs[-1]["word"] += word_ids + else: + augmented_convs.append({ + "conv_id": conv_id, + "role": utt["role"], + "text": text_token_ids, + "item": related_item, + "entity": related_entity, + "word": related_word + }) + last_role = utt["role"] + + return augmented_convs + + def _augment_and_add(self, raw_conv_dict): + augmented_conv_dicts = [] + context_tokens, context_entities, context_words, context_items = [], [], [], [] + entity_set, word_set = set(), set() + for i, conv in enumerate(raw_conv_dict): + text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"] + if len(context_tokens) > 0: + conv_dict = { + "conv_id": conv["conv_id"], + "role": conv["role"], + "tokens": context_tokens, + "response": text_tokens, + "item": context_items, + "entity": context_entities, + "word": context_words, + "items": movies + } + augmented_conv_dicts.append(conv_dict) + + context_tokens.append(text_tokens) + context_items += movies + for entity in entities + movies: + if entity not in entity_set: + entity_set.add(entity) + context_entities.append(entity) + for word in words: + if word not in word_set: + word_set.add(word) + context_words.append(word) + + return augmented_conv_dicts \ No newline at end of file diff --git a/HiCore/crslab/data/dataset/hredial/resources.py b/HiCore/crslab/data/dataset/hredial/resources.py new file mode 100644 index 0000000..cd3d6b2 --- /dev/null +++ b/HiCore/crslab/data/dataset/hredial/resources.py @@ -0,0 +1,66 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/1 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/22 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +resources = { + 'nltk': { + 'version': '0.31', + 'file': DownloadableFile( + 
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', + 'redial_nltk.zip', + '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0 + }, + }, + 'bert': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', + 'redial_bert.zip', + 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'gpt2': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', + 'redial_gpt2.zip', + '37b1a64032241903a37b5e014ee36e50d09f7e4a849058688e9af52027a3ac36', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } +} diff --git a/HiCore/crslab/data/dataset/htgredial/__init__.py b/HiCore/crslab/data/dataset/htgredial/__init__.py new file mode 100644 index 0000000..3c7d129 --- /dev/null +++ b/HiCore/crslab/data/dataset/htgredial/__init__.py @@ -0,0 +1 @@ +from .htgredial import HTGReDialDataset diff --git a/HiCore/crslab/data/dataset/htgredial/htgredial.py b/HiCore/crslab/data/dataset/htgredial/htgredial.py new file mode 100644 index 0000000..cfcc58c --- /dev/null +++ b/HiCore/crslab/data/dataset/htgredial/htgredial.py @@ -0,0 +1,210 @@ +# @Time : 2020/11/22 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/11/23, 2021/1/3, 2020/12/19 +# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou +# @Email : 
francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail + +r""" +ReDial +====== +References: + Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS 2018. + +.. _`"Towards deep conversational recommendations."`: + https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html + +""" + +import json +import os +import pickle as pkl +from copy import copy + +from loguru import logger +from tqdm import tqdm + +from crslab.config import DATASET_PATH +from crslab.data.dataset.base import BaseDataset +from .resources import resources + + +class HTGReDialDataset(BaseDataset): + """ + + Attributes: + train_data: train dataset. + valid_data: valid dataset. + test_data: test dataset. + vocab (dict): :: + + { + 'tok2ind': map from token to index, + 'ind2tok': map from index to token, + 'entity2id': map from entity to index, + 'id2entity': map from index to entity, + 'word2id': map from word to index, + 'vocab_size': len(self.tok2ind), + 'n_entity': max(self.entity2id.values()) + 1, + 'n_word': max(self.word2id.values()) + 1, + } + + Notes: + ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``. + + """ + + def __init__(self, opt, tokenize, restore=False, save=False): + """Specify tokenized resource and init base dataset. + + Args: + opt (Config or dict): config for dataset or the whole system. + tokenize (str): how to tokenize dataset. + restore (bool): whether to restore saved dataset which has been processed. Defaults to False. + save (bool): whether to save dataset after processing. Defaults to False. 
+ + """ + resource = resources[tokenize] + dpath = os.path.join(DATASET_PATH, "htgredial", tokenize) + super().__init__(opt, dpath, resource, restore, save) + + def _load_data(self): + train_data, valid_data, test_data = self._load_raw_data() + self._load_vocab() + self._load_other_data() + + vocab = { + 'tok2ind': self.tok2ind, + 'ind2tok': self.ind2tok, + 'entity2id': self.entity2id, + 'id2entity': self.id2entity, + 'vocab_size': len(self.tok2ind), + 'n_entity': self.n_entity + } + + return train_data, valid_data, test_data, vocab + + def _load_raw_data(self): + # load train/valid/test data + with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: + train_data = json.load(f) + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: + valid_data = json.load(f) + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: + test_data = json.load(f) + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + + return train_data, valid_data, test_data + + def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} + + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + + def _load_other_data(self): + # edge extension data + self.conv2items = json.load(open(os.path.join(self.dpath, 'conv2items.json'), 'r', encoding='utf-8')) + # hypergraph + self.user2hypergraph = json.load(open(os.path.join(self.dpath, 'user2hypergraph.json'), 'r', encoding='utf-8')) + # dbpedia + self.entity2id 
= json.load( + open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.n_entity = max(self.entity2id.values()) + 1 + self.side_data = pkl.load(open(os.path.join(self.dpath, 'side_data.pkl'), 'rb')) + logger.debug( + f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]") + + def _data_preprocess(self, train_data, valid_data, test_data): + processed_train_data = self._raw_data_process(train_data) + logger.debug("[Finish train data process]") + processed_valid_data = self._raw_data_process(valid_data) + logger.debug("[Finish valid data process]") + processed_test_data = self._raw_data_process(test_data) + logger.debug("[Finish test data process]") + processed_side_data = self.side_data + logger.debug("[Finish side data process]") + return processed_train_data, processed_valid_data, processed_test_data, processed_side_data + + def _raw_data_process(self, raw_data): + augmented_convs = [self._convert_to_id(conv) for convs in tqdm(raw_data) for conv in convs] + augmented_conv_dicts = [] + for conv in tqdm(augmented_convs): + augmented_conv_dicts.extend(self._augment_and_add(conv)) + return augmented_conv_dicts + + # 将文本、电影、实体信息转换为序号 + def _convert_to_id(self, conversation): + augmented_convs = [] + last_role = None + conv_id = conversation["conv_id"] + related_item = [] + related_entity = [] + related_word = [] + for utt in conversation['dialog']: + self.unk_token_idx = 3 + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + item_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.tok2ind[word] for word in utt['text'] if word in self.tok2ind] + + related_item += item_ids + related_entity += 
entity_ids + related_word += word_ids + + if utt["role"] == last_role: + augmented_convs[-1]["text"] += text_token_ids + augmented_convs[-1]["item"] += item_ids + augmented_convs[-1]["entity"] += entity_ids + augmented_convs[-1]["word"] += word_ids + else: + augmented_convs.append({ + "conv_id": conv_id, + "role": utt["role"], + "text": text_token_ids, + "item": related_item, + "entity": related_entity, + "word": related_word + }) + last_role = utt["role"] + + return augmented_convs + + def _augment_and_add(self, raw_conv_dict): + augmented_conv_dicts = [] + context_tokens, context_entities, context_words, context_items = [], [], [], [] + entity_set, word_set = set(), set() + for i, conv in enumerate(raw_conv_dict): + text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"] + if len(context_tokens) > 0: + conv_dict = { + "conv_id": conv["conv_id"], + "role": conv["role"], + "tokens": copy(context_tokens), + "response": text_tokens, + "item": copy(context_items), + "entity": copy(context_entities), + "word": copy(context_words), + "items": movies + } + augmented_conv_dicts.append(conv_dict) + + context_tokens.append(text_tokens) + context_items += movies + for entity in entities + movies: + if entity not in entity_set: + entity_set.add(entity) + context_entities.append(entity) + for word in words: + if word not in word_set: + word_set.add(word) + context_words.append(word) + + return augmented_conv_dicts \ No newline at end of file diff --git a/HiCore/crslab/data/dataset/htgredial/resources.py b/HiCore/crslab/data/dataset/htgredial/resources.py new file mode 100644 index 0000000..411bb43 --- /dev/null +++ b/HiCore/crslab/data/dataset/htgredial/resources.py @@ -0,0 +1,66 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/1 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/22 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +resources 
= { + 'pkuseg': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', + 'redial_nltk.zip', + '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0 + }, + }, + 'bert': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', + 'redial_bert.zip', + 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'gpt2': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', + 'redial_gpt2.zip', + '37b1a64032241903a37b5e014ee36e50d09f7e4a849058688e9af52027a3ac36', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } +} diff --git a/HiCore/crslab/data/dataset/opendialkg/__init__.py b/HiCore/crslab/data/dataset/opendialkg/__init__.py new file mode 100644 index 0000000..3afab94 --- /dev/null +++ b/HiCore/crslab/data/dataset/opendialkg/__init__.py @@ -0,0 +1 @@ +from .opendialkg import OpenDialKGDataset diff --git a/HiCore/crslab/data/dataset/opendialkg/opendialkg.py b/HiCore/crslab/data/dataset/opendialkg/opendialkg.py new file mode 100644 index 0000000..2ec54b3 --- /dev/null +++ b/HiCore/crslab/data/dataset/opendialkg/opendialkg.py @@ -0,0 +1,286 @@ +# @Time : 2020/12/19 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/12/20, 2021/1/2 +# 
@Author : Kun Zhou, Xiaolei Wang +# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com + +r""" +OpenDialKG +========== +References: + Moon, Seungwhan, et al. `"Opendialkg: Explainable conversational reasoning with attention-based walks over knowledge graphs."`_ in ACL 2019. + +.. _`"Opendialkg: Explainable conversational reasoning with attention-based walks over knowledge graphs."`: + https://www.aclweb.org/anthology/P19-1081/ + +""" + +import json +import os +from collections import defaultdict +from copy import copy + +from loguru import logger +from tqdm import tqdm + +from crslab.config import DATASET_PATH +from crslab.data.dataset.base import BaseDataset +from .resources import resources + + +class OpenDialKGDataset(BaseDataset): + """ + + Attributes: + train_data: train dataset. + valid_data: valid dataset. + test_data: test dataset. + vocab (dict): :: + + { + 'tok2ind': map from token to index, + 'ind2tok': map from index to token, + 'entity2id': map from entity to index, + 'id2entity': map from index to entity, + 'word2id': map from word to index, + 'vocab_size': len(self.tok2ind), + 'n_entity': max(self.entity2id.values()) + 1, + 'n_word': max(self.word2id.values()) + 1, + } + + Notes: + ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``. + + """ + + def __init__(self, opt, tokenize, restore=False, save=False): + """Specify tokenized resource and init base dataset. + + Args: + opt (Config or dict): config for dataset or the whole system. + tokenize (str): how to tokenize dataset. + restore (bool): whether to restore saved dataset which has been processed. Defaults to False. + save (bool): whether to save dataset after processing. Defaults to False. 
+ + """ + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] + self.unk_token_idx = self.special_token_idx['unk'] + dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize) + self.train_review = None + self.valid_review = None + self.test_review = None + super().__init__(opt, dpath, resource, restore, save) + + def _load_data(self): + train_data, valid_data, test_data = self._load_raw_data() + self._load_vocab() + self._load_other_data() + + vocab = { + 'tok2ind': self.tok2ind, + 'ind2tok': self.ind2tok, + 'entity2id': self.entity2id, + 'id2entity': self.id2entity, + 'word2id': self.word2id, + 'vocab_size': len(self.tok2ind), + 'n_entity': self.n_entity, + 'n_word': self.n_word, + } + vocab.update(self.special_token_idx) + + return train_data, valid_data, test_data, vocab + + def _load_raw_data(self): + # load train/valid/test data + with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: + train_data = json.load(f) + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: + valid_data = json.load(f) + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: + test_data = json.load(f) + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + + return train_data, valid_data, test_data + + def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} + + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + + def _load_other_data(self): + # 
opendialkg + self.entity2id = json.load( + open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.n_entity = max(self.entity2id.values()) + 1 + # {head_entity_id: [(relation_id, tail_entity_id)]} + self.entity_kg = open(os.path.join(self.dpath, 'opendialkg_subkg.txt'), encoding='utf-8') + logger.debug( + f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'opendialkg_subkg.json')} and {os.path.join(self.dpath, 'opendialkg_triples.txt')}]") + + # conceptnet + # {concept: concept_id} + self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) + self.n_word = max(self.word2id.values()) + 1 + # {concept \t relation\t concept} + self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8') + logger.debug( + f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]") + + def _data_preprocess(self, train_data, valid_data, test_data): + processed_train_data = self._raw_data_process(train_data) + logger.debug("[Finish train data process]") + processed_valid_data = self._raw_data_process(valid_data) + logger.debug("[Finish valid data process]") + processed_test_data = self._raw_data_process(test_data) + logger.debug("[Finish test data process]") + processed_side_data = self._side_data_process() + logger.debug("[Finish side data process]") + return processed_train_data, processed_valid_data, processed_test_data, processed_side_data + + def _raw_data_process(self, raw_data): + augmented_convs = [self._convert_to_id(idx, conversation) for idx, conversation in enumerate(tqdm(raw_data))] + augmented_conv_dicts = [] + for conv in tqdm(augmented_convs): + augmented_conv_dicts.extend(self._augment_and_add(conv)) + return augmented_conv_dicts + + def _convert_to_id(self, idx, conversation): + augmented_convs = [] + last_role 
= None + conv_id = conversation.get("conv_id", idx) + related_item = [] + related_entity = [] + related_word = [] + for utt in conversation['dialog']: + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + + related_item += item_ids + related_entity += entity_ids + related_word += word_ids + + if utt["role"] == last_role: + augmented_convs[-1]["text"] += text_token_ids + augmented_convs[-1]["item"] += item_ids + augmented_convs[-1]["entity"] += entity_ids + augmented_convs[-1]["word"] += word_ids + else: + augmented_convs.append({ + "conv_id": conv_id, + "role": utt["role"], + "text": text_token_ids, + "item": related_item, + "entity": related_entity, + "word": related_word + }) + last_role = utt["role"] + + return augmented_convs + + def _augment_and_add(self, raw_conv_dict): + augmented_conv_dicts = [] + context_tokens, context_entities, context_words, context_items = [], [], [], [] + entity_set, word_set = set(), set() + for i, conv in enumerate(raw_conv_dict): + text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"] + if len(context_tokens) > 0: + conv_dict = { + "conv_id": conv["conv_id"], + "role": conv["role"], + "tokens": copy(context_tokens), + "response": text_tokens, + "item": copy(context_items), + "entity": copy(context_entities), + "word": copy(context_words), + "items": movies + } + augmented_conv_dicts.append(conv_dict) + + context_tokens.append(text_tokens) + context_items += movies + for entity in entities + movies: + if entity not in entity_set: + entity_set.add(entity) + context_entities.append(entity) + for word in words: + if word not in word_set: + word_set.add(word) + context_words.append(word) + + return 
augmented_conv_dicts + + def _side_data_process(self): + processed_entity_kg = self._entity_kg_process() + logger.debug("[Finish entity KG process]") + processed_word_kg = self._word_kg_process() + logger.debug("[Finish word KG process]") + item_entity_ids = json.load(open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8')) + logger.debug('[Load item entity ids]') + + side_data = { + "entity_kg": processed_entity_kg, + "word_kg": processed_word_kg, + "item_entity_ids": item_entity_ids, + } + return side_data + + def _entity_kg_process(self): + edge_list = [] # [(entity, entity, relation)] + for line in self.entity_kg: + triple = line.strip().split('\t') + if len(triple) != 3 or triple[0] not in self.entity2id or triple[2] not in self.entity2id: + continue + e0 = self.entity2id[triple[0]] + e1 = self.entity2id[triple[2]] + r = triple[1] + edge_list.append((e0, e1, r)) + # edge_list.append((e1, e0, r)) + edge_list.append((e0, e0, 'SELF_LOOP')) + if e1 != e0: + edge_list.append((e1, e1, 'SELF_LOOP')) + + relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set() + for h, t, r in edge_list: + relation_cnt[r] += 1 + for h, t, r in edge_list: + if relation_cnt[r] > 20000: + if r not in relation2id: + relation2id[r] = len(relation2id) + edges.add((h, t, relation2id[r])) + entities.add(self.id2entity[h]) + entities.add(self.id2entity[t]) + + return { + 'edge': list(edges), + 'n_relation': len(relation2id), + 'entity': list(entities) + } + + def _word_kg_process(self): + edges = set() # {(entity, entity)} + entities = set() + for line in self.word_kg: + triple = line.strip().split('\t') + entities.add(triple[0]) + entities.add(triple[2]) + e0 = self.word2id[triple[0]] + e1 = self.word2id[triple[2]] + edges.add((e0, e1)) + edges.add((e1, e0)) + # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]] + return { + 'edge': list(edges), + 'entity': list(entities) + } diff --git 
a/HiCore/crslab/data/dataset/opendialkg/resources.py b/HiCore/crslab/data/dataset/opendialkg/resources.py new file mode 100644 index 0000000..9f7fb62 --- /dev/null +++ b/HiCore/crslab/data/dataset/opendialkg/resources.py @@ -0,0 +1,66 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/21 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/22 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +resources = { + 'nltk': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1', + 'opendialkg_nltk.zip', + '6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'bert': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1', + 'opendialkg_bert.zip', + '0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'gpt2': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1', + 'opendialkg_gpt2.zip', + 'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } +} diff --git a/HiCore/crslab/data/dataset/redial/__init__.py b/HiCore/crslab/data/dataset/redial/__init__.py new file mode 
100644 index 0000000..621b2be --- /dev/null +++ b/HiCore/crslab/data/dataset/redial/__init__.py @@ -0,0 +1 @@ +from .redial import ReDialDataset \ No newline at end of file diff --git a/HiCore/crslab/data/dataset/redial/redial.py b/HiCore/crslab/data/dataset/redial/redial.py new file mode 100644 index 0000000..9d463e6 --- /dev/null +++ b/HiCore/crslab/data/dataset/redial/redial.py @@ -0,0 +1,269 @@ +# @Time : 2020/11/22 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/11/23, 2021/1/3, 2020/12/19 +# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou +# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail + +r""" +ReDial +====== +References: + Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS 2018. + +.. _`"Towards deep conversational recommendations."`: + https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html + +""" + +import json +import os +from collections import defaultdict +from copy import copy + +from loguru import logger +from tqdm import tqdm + +from crslab.config import DATASET_PATH +from crslab.data.dataset.base import BaseDataset +from .resources import resources + + +class ReDialDataset(BaseDataset): + """ + + Attributes: + train_data: train dataset. + valid_data: valid dataset. + test_data: test dataset. + vocab (dict): :: + + { + 'tok2ind': map from token to index, + 'ind2tok': map from index to token, + 'entity2id': map from entity to index, + 'id2entity': map from index to entity, + 'word2id': map from word to index, + 'vocab_size': len(self.tok2ind), + 'n_entity': max(self.entity2id.values()) + 1, + 'n_word': max(self.word2id.values()) + 1, + } + + Notes: + ``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``. + + """ + + def __init__(self, opt, tokenize, restore=False, save=False): + """Specify tokenized resource and init base dataset. + + Args: + opt (Config or dict): config for dataset or the whole system. 
+ tokenize (str): how to tokenize dataset. + restore (bool): whether to restore saved dataset which has been processed. Defaults to False. + save (bool): whether to save dataset after processing. Defaults to False. + + """ + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] + self.unk_token_idx = self.special_token_idx['unk'] + dpath = os.path.join(DATASET_PATH, "redial", tokenize) + super().__init__(opt, dpath, resource, restore, save) + + def _load_data(self): + train_data, valid_data, test_data = self._load_raw_data() + self._load_vocab() + self._load_other_data() + + vocab = { + 'tok2ind': self.tok2ind, + 'ind2tok': self.ind2tok, + 'entity2id': self.entity2id, + 'id2entity': self.id2entity, + 'word2id': self.word2id, + 'vocab_size': len(self.tok2ind), + 'n_entity': self.n_entity, + 'n_word': self.n_word, + } + vocab.update(self.special_token_idx) + + return train_data, valid_data, test_data, vocab + + def _load_raw_data(self): + # load train/valid/test data + with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: + train_data = json.load(f) + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: + valid_data = json.load(f) + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: + test_data = json.load(f) + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + + return train_data, valid_data, test_data + + def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} + + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is 
{len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + + def _load_other_data(self): + # dbpedia + self.entity2id = json.load( + open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id} + self.id2entity = {idx: entity for entity, idx in self.entity2id.items()} + self.n_entity = max(self.entity2id.values()) + 1 + # {head_entity_id: [(relation_id, tail_entity_id)]} + self.entity_kg = json.load(open(os.path.join(self.dpath, 'dbpedia_subkg.json'), 'r', encoding='utf-8')) + logger.debug( + f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]") + + # conceptNet + # {concept: concept_id} + self.word2id = json.load(open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8')) + self.n_word = max(self.word2id.values()) + 1 + # {relation\t concept \t concept} + self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8') + logger.debug( + f"[Load word dictionary and KG from {os.path.join(self.dpath, 'concept2id.json')} and {os.path.join(self.dpath, 'conceptnet_subkg.txt')}]") + + def _data_preprocess(self, train_data, valid_data, test_data): + processed_train_data = self._raw_data_process(train_data) + logger.debug("[Finish train data process]") + processed_valid_data = self._raw_data_process(valid_data) + logger.debug("[Finish valid data process]") + processed_test_data = self._raw_data_process(test_data) + logger.debug("[Finish test data process]") + processed_side_data = self._side_data_process() + logger.debug("[Finish side data process]") + return processed_train_data, processed_valid_data, processed_test_data, processed_side_data + + def _raw_data_process(self, raw_data): + augmented_convs = [self._merge_conv_data(conversation["dialog"]) for conversation in tqdm(raw_data)] + augmented_conv_dicts = [] + for conv in tqdm(augmented_convs): + 
augmented_conv_dicts.extend(self._augment_and_add(conv)) + return augmented_conv_dicts + + def _merge_conv_data(self, dialog): + augmented_convs = [] + last_role = None + for utt in dialog: + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + + if utt["role"] == last_role: + augmented_convs[-1]["text"] += text_token_ids + augmented_convs[-1]["movie"] += movie_ids + augmented_convs[-1]["entity"] += entity_ids + augmented_convs[-1]["word"] += word_ids + else: + augmented_convs.append({ + "role": utt["role"], + "text": text_token_ids, + "entity": entity_ids, + "movie": movie_ids, + "word": word_ids + }) + last_role = utt["role"] + + return augmented_convs + + def _augment_and_add(self, raw_conv_dict): + augmented_conv_dicts = [] + context_tokens, context_entities, context_words, context_items = [], [], [], [] + entity_set, word_set = set(), set() + for i, conv in enumerate(raw_conv_dict): + text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["movie"], conv["word"] + if len(context_tokens) > 0: + conv_dict = { + "conv_id": i, + "role": conv['role'], + "tokens": copy(context_tokens), + "response": text_tokens, + "entity": copy(context_entities), + "word": copy(context_words), + "item": copy(context_items), + "items": movies, + } + augmented_conv_dicts.append(conv_dict) + + context_tokens.append(text_tokens) + context_items += movies + for entity in entities + movies: + if entity not in entity_set: + entity_set.add(entity) + context_entities.append(entity) + for word in words: + if word not in word_set: + word_set.add(word) + context_words.append(word) + + return augmented_conv_dicts + + def _side_data_process(self): + processed_entity_kg = 
self._entity_kg_process() + logger.debug("[Finish entity KG process]") + processed_word_kg = self._word_kg_process() + logger.debug("[Finish word KG process]") + movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) + logger.debug('[Load movie entity ids]') + + side_data = { + "entity_kg": processed_entity_kg, + "word_kg": processed_word_kg, + "item_entity_ids": movie_entity_ids, + } + return side_data + + def _entity_kg_process(self, SELF_LOOP_ID=185): + edge_list = [] # [(entity, entity, relation)] + for entity in range(self.n_entity): + if str(entity) not in self.entity_kg: + continue + edge_list.append((entity, entity, SELF_LOOP_ID)) # add self loop + for tail_and_relation in self.entity_kg[str(entity)]: + if entity != tail_and_relation[1] and tail_and_relation[0] != SELF_LOOP_ID: + edge_list.append((entity, tail_and_relation[1], tail_and_relation[0])) + edge_list.append((tail_and_relation[1], entity, tail_and_relation[0])) + + relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set() + for h, t, r in edge_list: + relation_cnt[r] += 1 + for h, t, r in edge_list: + if relation_cnt[r] > 1000: + if r not in relation2id: + relation2id[r] = len(relation2id) + edges.add((h, t, relation2id[r])) + entities.add(self.id2entity[h]) + entities.add(self.id2entity[t]) + return { + 'edge': list(edges), + 'n_relation': len(relation2id), + 'entity': list(entities) + } + + def _word_kg_process(self): + edges = set() # {(entity, entity)} + entities = set() + for line in self.word_kg: + kg = line.strip().split('\t') + entities.add(kg[1].split('/')[0]) + entities.add(kg[2].split('/')[0]) + e0 = self.word2id[kg[1].split('/')[0]] + e1 = self.word2id[kg[2].split('/')[0]] + edges.add((e0, e1)) + edges.add((e1, e0)) + # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]] + return { + 'edge': list(edges), + 'entity': list(entities) + } \ No newline at end of file diff --git 
a/HiCore/crslab/data/dataset/redial/resources.py b/HiCore/crslab/data/dataset/redial/resources.py new file mode 100644 index 0000000..7194843 --- /dev/null +++ b/HiCore/crslab/data/dataset/redial/resources.py @@ -0,0 +1,66 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/1 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/22 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +resources = { + 'nltk': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1', + 'redial_nltk.zip', + '01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0 + }, + }, + 'bert': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1', + 'redial_bert.zip', + 'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + }, + }, + 'gpt2': { + 'version': '0.31', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1', + 'redial_gpt2.zip', + '15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8', + ), + 'special_token_idx': { + 'pad': -100, + 'start': 1, + 'end': 2, + 'unk': 3, + 'sent_split': 4, + 'word_split': 5, + 'pad_entity': 0, + 'pad_word': 0 + }, + } +} \ No newline at end of file diff --git a/HiCore/crslab/data/dataset/tgredial/__init__.py b/HiCore/crslab/data/dataset/tgredial/__init__.py 
new file mode 100644 index 0000000..1cd18b9 --- /dev/null +++ b/HiCore/crslab/data/dataset/tgredial/__init__.py @@ -0,0 +1 @@ +from .tgredial import TGReDialDataset \ No newline at end of file diff --git a/HiCore/crslab/data/dataset/tgredial/resources.py b/HiCore/crslab/data/dataset/tgredial/resources.py new file mode 100644 index 0000000..6c0e821 --- /dev/null +++ b/HiCore/crslab/data/dataset/tgredial/resources.py @@ -0,0 +1,71 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/4 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/22 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +resources = { + 'pkuseg': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1', + 'tgredial_pkuseg.zip', + '8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6', + ), + 'special_token_idx': { + 'pad': 0, + 'start': 1, + 'end': 2, + 'unk': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, + }, + 'bert': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1', + 'tgredial_bert.zip', + 'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 'end': 102, + 'unk': 100, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0 + }, + }, + 'gpt2': { + 'version': '0.3', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1', + 'tgredial_gpt2.zip', + '2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9' + ), + 'special_token_idx': { + 'pad': 0, + 'start': 101, + 
'end': 102, + 'unk': 100, + 'cls': 101, + 'sep': 102, + 'sent_split': 2, + 'word_split': 3, + 'pad_entity': 0, + 'pad_word': 0, + 'pad_topic': 0, + }, + } +} \ No newline at end of file diff --git a/HiCore/crslab/data/dataset/tgredial/tgredial.py b/HiCore/crslab/data/dataset/tgredial/tgredial.py new file mode 100644 index 0000000..5fbd7bc --- /dev/null +++ b/HiCore/crslab/data/dataset/tgredial/tgredial.py @@ -0,0 +1,343 @@ +# @Time : 2020/12/4 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/12/6, 2021/1/2, 2020/12/19 +# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou +# @Email : francis_kun_zhou@163.com, sdzyh002@gmail + +r""" +TGReDial +======== +References: + Zhou, Kun, et al. `"Towards Topic-Guided Conversational Recommender System."`_ in COLING 2020. + +.. _`"Towards Topic-Guided Conversational Recommender System."`: + https://www.aclweb.org/anthology/2020.coling-main.365/ + +""" + +import json +import os +from collections import defaultdict +from copy import copy +import numpy as np +from loguru import logger +from tqdm import tqdm + +from crslab.config import DATASET_PATH +from crslab.data.dataset.base import BaseDataset +from .resources import resources + + +class TGReDialDataset(BaseDataset): + """ + + Attributes: + train_data: train dataset. + valid_data: valid dataset. + test_data: test dataset. + vocab (dict): :: + + { + 'tok2ind': map from token to index, + 'ind2tok': map from index to token, + 'topic2ind': map from topic to index, + 'ind2topic': map from index to topic, + 'entity2id': map from entity to index, + 'id2entity': map from index to entity, + 'word2id': map from word to index, + 'vocab_size': len(self.tok2ind), + 'n_topic': len(self.topic2ind) + 1, + 'n_entity': max(self.entity2id.values()) + 1, + 'n_word': max(self.word2id.values()) + 1, + } + + Notes: + ``'unk'`` and ``'pad_topic'`` must be specified in ``'special_token_idx'`` in ``resources.py``. 
+ + """ + + def __init__(self, opt, tokenize, restore=False, save=False): + """Specify tokenized resource and init base dataset. + + Args: + opt (Config or dict): config for dataset or the whole system. + tokenize (str): how to tokenize dataset. + restore (bool): whether to restore saved dataset which has been processed. Defaults to False. + save (bool): whether to save dataset after processing. Defaults to False. + + """ + resource = resources[tokenize] + self.special_token_idx = resource['special_token_idx'] + self.unk_token_idx = self.special_token_idx['unk'] + self.pad_topic_idx = self.special_token_idx['pad_topic'] + dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize) + self.replace_token = opt.get('replace_token',None) + self.replace_token_idx = opt.get('replace_token_idx',None) + super().__init__(opt, dpath, resource, restore, save) + if self.replace_token: + if self.replace_token_idx: + self.side_data["embedding"][self.replace_token_idx] = self.side_data['embedding'][0] + else: + self.side_data["embedding"] = np.insert(self.side_data["embedding"],len(self.side_data["embedding"]),self.side_data['embedding'][0],axis=0) + + + def _load_data(self): + train_data, valid_data, test_data = self._load_raw_data() + self._load_vocab() + self._load_other_data() + + vocab = { + 'tok2ind': self.tok2ind, + 'ind2tok': self.ind2tok, + 'topic2ind': self.topic2ind, + 'ind2topic': self.ind2topic, + 'entity2id': self.entity2id, + 'id2entity': self.id2entity, + 'word2id': self.word2id, + 'vocab_size': len(self.tok2ind), + 'n_topic': len(self.topic2ind) + 1, + 'n_entity': self.n_entity, + 'n_word': self.n_word, + } + vocab.update(self.special_token_idx) + + return train_data, valid_data, test_data, vocab + + def _load_raw_data(self): + # load train/valid/test data + with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f: + train_data = json.load(f) + logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]") + with 
open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f: + valid_data = json.load(f) + logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]") + with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f: + test_data = json.load(f) + logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]") + + return train_data, valid_data, test_data + + def _load_vocab(self): + self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8')) + self.ind2tok = {idx: word for word, idx in self.tok2ind.items()} + # add special tokens + if self.replace_token: + if self.replace_token not in self.tok2ind: + if self.replace_token_idx: + self.ind2tok[self.replace_token_idx] = self.replace_token + self.tok2ind[self.replace_token] = self.replace_token_idx + self.special_token_idx[self.replace_token] = self.replace_token_idx + else: + self.ind2tok[len(self.tok2ind)] = self.replace_token + self.tok2ind[self.replace_token] = len(self.tok2ind) + self.special_token_idx[self.replace_token] = len(self.tok2ind)-1 + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]") + + self.topic2ind = json.load(open(os.path.join(self.dpath, 'topic2id.json'), 'r', encoding='utf-8')) + self.ind2topic = {idx: word for word, idx in self.topic2ind.items()} + + logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'topic2id.json')}]") + logger.debug(f"[The size of token2index dictionary is {len(self.topic2ind)}]") + logger.debug(f"[The size of index2token dictionary is {len(self.ind2topic)}]") + + def _load_other_data(self): + # cn-dbpedia + self.entity2id = json.load( + open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id} + self.id2entity = {idx: entity for entity, 
idx in self.entity2id.items()} + self.n_entity = max(self.entity2id.values()) + 1 + # {head_entity_id: [(relation_id, tail_entity_id)]} + self.entity_kg = open(os.path.join(self.dpath, 'cn-dbpedia.txt'), encoding='utf-8') + logger.debug( + f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'cn-dbpedia.txt')}]") + + # hownet + # {concept: concept_id} + self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8')) + self.n_word = max(self.word2id.values()) + 1 + # {relation\t concept \t concept} + self.word_kg = open(os.path.join(self.dpath, 'hownet.txt'), encoding='utf-8') + logger.debug( + f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet.txt')}]") + + # user interaction history dictionary + self.conv2history = json.load(open(os.path.join(self.dpath, 'user2history.json'), 'r', encoding='utf-8')) + logger.debug(f"[Load user interaction history from {os.path.join(self.dpath, 'user2history.json')}]") + + # user profile + self.user2profile = json.load(open(os.path.join(self.dpath, 'user2profile.json'), 'r', encoding='utf-8')) + logger.debug(f"[Load user profile from {os.path.join(self.dpath, 'user2profile.json')}") + + + def _data_preprocess(self, train_data, valid_data, test_data): + processed_train_data = self._raw_data_process(train_data) + logger.debug("[Finish train data process]") + processed_valid_data = self._raw_data_process(valid_data) + logger.debug("[Finish valid data process]") + processed_test_data = self._raw_data_process(test_data) + logger.debug("[Finish test data process]") + processed_side_data = self._side_data_process() + logger.debug("[Finish side data process]") + return processed_train_data, processed_valid_data, processed_test_data, processed_side_data + + def _raw_data_process(self, raw_data): + augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)] 
+ augmented_conv_dicts = [] + for conv in tqdm(augmented_convs): + augmented_conv_dicts.extend(self._augment_and_add(conv)) + return augmented_conv_dicts + + def _convert_to_id(self, conversation): + augmented_convs = [] + last_role = None + for utt in conversation['messages']: + assert utt['role'] != last_role + # change movies into slots + if self.replace_token: + if len(utt['movie']) != 0: + while '《' in utt['text'] : + begin = utt['text'].index("《") + end = utt['text'].index("》") + utt['text'] = utt['text'][:begin] + [self.replace_token] + utt['text'][end+1:] + text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]] + movie_ids = [self.entity2id[movie] for movie in utt['movie'] if movie in self.entity2id] + entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id] + word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id] + policy = [] + for action, kw in zip(utt['target'][1::2], utt['target'][2::2]): + if kw is None or action == '推荐电影': + continue + if isinstance(kw, str): + kw = [kw] + kw = [self.topic2ind.get(k, self.pad_topic_idx) for k in kw] + policy.append([action, kw]) + final_kws = [self.topic2ind[kw] if kw is not None else self.pad_topic_idx for kw in utt['final'][1]] + final = [utt['final'][0], final_kws] + conv_utt_id = str(conversation['conv_id']) + '/' + str(utt['local_id']) + interaction_history = self.conv2history.get(conv_utt_id, []) + user_profile = self.user2profile[conversation['user_id']] + user_profile = [[self.tok2ind.get(token, self.unk_token_idx) for token in sent] for sent in user_profile] + + augmented_convs.append({ + "role": utt["role"], + "text": text_token_ids, + "entity": entity_ids, + "movie": movie_ids, + "word": text_token_ids, + 'policy': policy, + 'final': final, + 'interaction_history': interaction_history, + 'user_profile': user_profile + }) + last_role = utt["role"] + + return augmented_convs + + def _augment_and_add(self, 
raw_conv_dict): + augmented_conv_dicts = [] + context_tokens, context_entities, context_words, context_policy, context_items = [], [], [], [], [] + entity_set, word_set = set(), set() + for i, conv in enumerate(raw_conv_dict): + text_tokens, entities, movies, words, policies = conv["text"], conv["entity"], conv["movie"], conv["word"], \ + conv['policy'] + if self.replace_token is not None: + if text_tokens.count(30000) != len(movies): + continue # the number of slots doesn't equal to the number of movies + + if len(context_tokens) > 0: + conv_dict = { + 'conv_id': i, + 'role': conv['role'], + 'user_profile': conv['user_profile'], + "tokens": copy(context_tokens), + "response": text_tokens, + "entity": copy(context_entities), + "word": copy(context_words), + 'interaction_history': conv['interaction_history'], + 'item': copy(context_items), + "items": movies, + 'context_policy': copy(context_policy), + 'target': policies, + 'final': conv['final'], + } + augmented_conv_dicts.append(conv_dict) + + context_tokens.append(text_tokens) + context_policy.append(policies) + context_items += movies + for entity in entities + movies: + if entity not in entity_set: + entity_set.add(entity) + context_entities.append(entity) + for word in words: + if word not in word_set: + word_set.add(word) + context_words.append(word) + + return augmented_conv_dicts + + def _side_data_process(self): + processed_entity_kg = self._entity_kg_process() + logger.debug("[Finish entity KG process]") + processed_word_kg = self._word_kg_process() + logger.debug("[Finish word KG process]") + movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8')) + logger.debug('[Load movie entity ids]') + + side_data = { + "entity_kg": processed_entity_kg, + "word_kg": processed_word_kg, + "item_entity_ids": movie_entity_ids, + } + return side_data + + def _entity_kg_process(self): + edge_list = [] # [(entity, entity, relation)] + for line in self.entity_kg: + triple = 
line.strip().split('\t') + e0 = self.entity2id[triple[0]] + e1 = self.entity2id[triple[2]] + r = triple[1] + edge_list.append((e0, e1, r)) + edge_list.append((e1, e0, r)) + edge_list.append((e0, e0, 'SELF_LOOP')) + if e1 != e0: + edge_list.append((e1, e1, 'SELF_LOOP')) + + relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set() + for h, t, r in edge_list: + relation_cnt[r] += 1 + for h, t, r in edge_list: + if r not in relation2id: + relation2id[r] = len(relation2id) + edges.add((h, t, relation2id[r])) + entities.add(self.id2entity[h]) + entities.add(self.id2entity[t]) + + return { + 'edge': list(edges), + 'n_relation': len(relation2id), + 'entity': list(entities) + } + + def _word_kg_process(self): + edges = set() # {(entity, entity)} + entities = set() + for line in self.word_kg: + triple = line.strip().split('\t') + entities.add(triple[0]) + entities.add(triple[2]) + e0 = self.word2id[triple[0]] + e1 = self.word2id[triple[2]] + edges.add((e0, e1)) + edges.add((e1, e0)) + # edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]] + return { + 'edge': list(edges), + 'entity': list(entities) + } \ No newline at end of file diff --git a/HiCore/crslab/download.py b/HiCore/crslab/download.py new file mode 100644 index 0000000..835ad50 --- /dev/null +++ b/HiCore/crslab/download.py @@ -0,0 +1,275 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/7 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/7 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +import hashlib +import os +import shutil +import time + +import datetime +import requests +import tqdm +from loguru import logger + + +class DownloadableFile: + """ + A class used to abstract any file that has to be downloaded online. + + Any task that needs to download a file needs to have a list RESOURCES + that have objects of this class as elements. 
+ + This class provides the following functionality: + + - Download a file from a URL + - Untar the file if zipped + - Checksum for the downloaded file + + An object of this class needs to be created with: + + - url : URL or Google Drive id to download from + - file_name : File name that the file should be named + - hashcode : SHA256 hashcode of the downloaded file + - zipped : False if the file is not compressed + - from_google : True if the file is from Google Drive + """ + + def __init__(self, url, file_name, hashcode, zipped=True, from_google=False): + self.url = url + self.file_name = file_name + self.hashcode = hashcode + self.zipped = zipped + self.from_google = from_google + + def checksum(self, dpath): + """ + Checksum on a given file. + + :param dpath: path to the downloaded file. + """ + sha256_hash = hashlib.sha256() + with open(os.path.join(dpath, self.file_name), "rb") as f: + for byte_block in iter(lambda: f.read(65536), b""): + sha256_hash.update(byte_block) + if sha256_hash.hexdigest() != self.hashcode: + # remove_dir(dpath) + raise AssertionError( + f"[ Checksum for {self.file_name} from \n{self.url}\n" + "does not match the expected checksum. Please try again. ]" + ) + else: + logger.debug("Checksum Successful") + pass + + def download_file(self, dpath): + if self.from_google: + download_from_google_drive(self.url, os.path.join(dpath, self.file_name)) + else: + download(self.url, dpath, self.file_name) + + self.checksum(dpath) + + if self.zipped: + untar(dpath, self.file_name) + + +def download(url, path, fname, redownload=False, num_retries=5): + """ + Download file using `requests`. + If ``redownload`` is set to false, then will not download tar file again if it is + present (default ``False``). 
+ """ + outfile = os.path.join(path, fname) + download = not os.path.exists(outfile) or redownload + logger.info(f"Downloading {url} to {outfile}") + retry = num_retries + exp_backoff = [2 ** r for r in reversed(range(retry))] + + pbar = tqdm.tqdm(unit='B', unit_scale=True, desc='Downloading {}'.format(fname)) + + while download and retry > 0: + response = None + try: + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.60', + } + response = requests.get(url, stream=True, headers=headers) + + # negative reply could be 'none' or just missing + CHUNK_SIZE = 32768 + total_size = int(response.headers.get('Content-Length', -1)) + # server returns remaining size if resuming, so adjust total + pbar.total = total_size + done = 0 + + with open(outfile, 'wb') as f: + for chunk in response.iter_content(CHUNK_SIZE): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if total_size > 0: + done += len(chunk) + if total_size < done: + # don't freak out if content-length was too small + total_size = done + pbar.total = total_size + pbar.update(len(chunk)) + break + except ( + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + ): + retry -= 1 + pbar.clear() + if retry > 0: + pl = 'y' if retry == 1 else 'ies' + logger.debug( + f'Connection error, retrying. ({retry} retr{pl} left)' + ) + time.sleep(exp_backoff[retry]) + else: + logger.error('Retried too many times, stopped retrying.') + finally: + if response: + response.close() + if retry <= 0: + raise RuntimeError('Connection broken too many times. Stopped retrying.') + + if download and retry > 0: + pbar.update(done - pbar.n) + if done < total_size: + raise RuntimeError( + f'Received less data than specified in Content-Length header for ' + f'{url}. There may be a download problem.' 
def _get_confirm_token(response):
    """Return Google Drive's download-warning cookie value, if any."""
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def download_from_google_drive(gd_id, destination):
    """
    Use the requests package to download a file from Google Drive.
    """
    URL = 'https://docs.google.com/uc?export=download'

    with requests.Session() as session:
        response = session.get(URL, params={'id': gd_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            # a warning cookie means Google wants an explicit confirmation round-trip
            response.close()
            response = session.get(URL, params={'id': gd_id, 'confirm': token}, stream=True)

        CHUNK_SIZE = 32768
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
        response.close()


def move(path1, path2):
    """Move/rename ``path1`` to ``path2``."""
    shutil.move(path1, path2)


def untar(path, fname, deleteTar=True):
    """
    Unpack archive ``fname`` located in ``path`` into that same directory.

    :param str path: folder containing the archive; receives the contents.
    :param str fname: filename of the archive.
    :param bool deleteTar: delete the archive after extraction if True.
    """
    logger.debug(f'unpacking {fname}')
    archive_path = os.path.join(path, fname)
    shutil.unpack_archive(archive_path, path)
    if deleteTar:
        os.remove(archive_path)


def make_dir(path):
    """Create ``path`` and any missing parents (`mkdir -p`)."""
    # the current working directory ('') is always a fine path
    if path != '':
        os.makedirs(path, exist_ok=True)


def remove_dir(path):
    """Remove ``path`` recursively; missing directories are ignored."""
    shutil.rmtree(path, ignore_errors=True)
+ """ + if version_string: + fname = os.path.join(path, '.built') + if not os.path.isfile(fname): + return False + else: + with open(fname, 'r') as read: + text = read.read().split('\n') + return len(text) > 1 and text[1] == version_string + else: + return os.path.isfile(os.path.join(path, '.built')) + + +def mark_done(path, version_string=None): + """ + Mark this path as prebuilt. + + Marks the path as done by adding a '.built' file with the current timestamp + plus a version description string if specified. + + :param str path: + The file path to mark as built. + + :param str version_string: + The version of this dataset. + """ + with open(os.path.join(path, '.built'), 'w') as write: + write.write(str(datetime.datetime.today())) + if version_string: + write.write('\n' + version_string) + + +def build(dpath, dfile, version=None): + if not check_build(dpath, version): + logger.info('[Building data: ' + dpath + ']') + if check_build(dpath): + remove_dir(dpath) + make_dir(dpath) + # Download the data. 
def get_evaluator(evaluator_name, dataset, file_path):
    """
    Look up and instantiate an evaluator by name.

    Args:
        evaluator_name (str): key into ``Evaluator_register_table`` (e.g. 'standard').
        dataset (str): dataset name, used to resolve the evaluation language.
        file_path: optional path where the evaluator dumps its result reports.

    Raises:
        NotImplementedError: if ``evaluator_name`` is not registered.
    """
    if evaluator_name in Evaluator_register_table:
        language = dataset_language_map[dataset]
        evaluator = Evaluator_register_table[evaluator_name](language, file_path)
        logger.info(f'[Build evaluator {evaluator_name}]')
        return evaluator
    else:
        # BUG FIX: message previously said "Model" — copy-paste from the model
        # registry; this table registers evaluators.
        raise NotImplementedError(f'Evaluator [{evaluator_name}] has not been implemented')
-0,0 +1,30 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/18 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/18 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +resources = { + 'zh': { + 'version': '0.2', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1', + 'cc.zh.300.zip', + 'effd9806809a1db106b5166b817aaafaaf3f005846f730d4c49f88c7a28a0ac3' + ) + }, + 'en': { + 'version': '0.2', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1', + 'cc.en.300.zip', + '96a06a77da70325997eaa52bfd9acb1359a7c3754cb1c1aed2fc27c04936d53e' + ) + } +} diff --git a/HiCore/crslab/evaluator/metrics/__init__.py b/HiCore/crslab/evaluator/metrics/__init__.py new file mode 100644 index 0000000..cca4e70 --- /dev/null +++ b/HiCore/crslab/evaluator/metrics/__init__.py @@ -0,0 +1,4 @@ +from .base import Metric, Metrics, aggregate_unnamed_reports, AverageMetric +from .gen import BleuMetric, ExactMatchMetric, F1Metric, DistMetric, EmbeddingAverage, VectorExtrema, \ + GreedyMatch +from .rec import RECMetric, NDCGMetric, MRRMetric, CovMetric diff --git a/HiCore/crslab/evaluator/metrics/base.py b/HiCore/crslab/evaluator/metrics/base.py new file mode 100644 index 0000000..93e3189 --- /dev/null +++ b/HiCore/crslab/evaluator/metrics/base.py @@ -0,0 +1,232 @@ +# @Time : 2020/11/22 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/11/24, 2020/12/2 +# @Author : Kun Zhou, Xiaolei Wang +# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com + +import functools +from abc import ABC, abstractmethod + +import torch +from typing import Any, Union, List, Optional, Dict + +TScalar = Union[int, float, torch.Tensor] +TVector = Union[List[TScalar], 
TScalar = Union[int, float, torch.Tensor]
TVector = Union[List[TScalar], torch.Tensor]


@functools.total_ordering
class Metric(ABC):
    """
    Base class for storing metrics.

    Subclasses should define .value().
    """

    @abstractmethod
    def value(self):
        """
        Return the value of the metric as a float.
        """
        pass

    @abstractmethod
    def __add__(self, other: Any) -> 'Metric':
        raise NotImplementedError

    def __iadd__(self, other):
        # += goes through __radd__ so that accumulating onto None-ish
        # operands keeps working
        return self.__radd__(other)

    def __radd__(self, other: Any):
        # None is the additive identity: summing a sequence of optional
        # metrics never needs a special seed value
        return self if other is None else self.__add__(other)

    def __str__(self) -> str:
        return f'{self.value():.4g}'

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.value():.4g})'

    def __float__(self) -> float:
        return float(self.value())

    def __int__(self) -> int:
        return int(self.value())

    def __eq__(self, other: Any) -> bool:
        rhs = other.value() if isinstance(other, Metric) else other
        return self.value() == rhs

    def __lt__(self, other: Any) -> bool:
        # total_ordering fills in <=, >, >= from __eq__ and __lt__
        rhs = other.value() if isinstance(other, Metric) else other
        return self.value() < rhs

    def __sub__(self, other: Any) -> float:
        """
        Used heavily for assertAlmostEqual.
        """
        if not isinstance(other, float):
            raise TypeError('Metrics.__sub__ is intentionally limited to floats.')
        return self.value() - other

    def __rsub__(self, other: Any) -> float:
        """
        Used heavily for assertAlmostEqual.

        NOTE: This is not necessary in python 3.7+.
        """
        if not isinstance(other, float):
            raise TypeError('Metrics.__rsub__ is intentionally limited to floats.')
        return other - self.value()

    @classmethod
    def as_number(cls, obj: TScalar) -> Union[int, float]:
        number = obj.item() if isinstance(obj, torch.Tensor) else obj
        assert isinstance(number, (int, float))
        return number

    @classmethod
    def as_float(cls, obj: TScalar) -> float:
        return float(cls.as_number(obj))

    @classmethod
    def as_int(cls, obj: TScalar) -> int:
        return int(cls.as_number(obj))

    @classmethod
    def many(cls, *objs: List[TVector]) -> List['Metric']:
        """
        Construct many Metrics from parallel vectors of parts.

        Useful when numerators and denominators are computed separately.
        """
        lengths = [len(o) for o in objs]
        if len(set(lengths)) != 1:
            raise IndexError(f'Uneven {cls.__name__} constructions: {lengths}')
        return [cls(*parts) for parts in zip(*objs)]


class SumMetric(Metric):
    """
    Keeps a running sum of some quantity, e.g. "exs", the number of examples
    seen since the last report.
    """

    __slots__ = ('_sum',)

    def __init__(self, sum_: TScalar = 0):
        if isinstance(sum_, torch.Tensor):
            self._sum = sum_.item()
        else:
            assert isinstance(sum_, (int, float, list))
            self._sum = sum_

    def __add__(self, other: Optional['SumMetric']) -> 'SumMetric':
        if other is None:
            return self
        # construct via type(self) so subclasses keep their concrete type
        return type(self)(sum_=self._sum + other._sum)

    def value(self) -> float:
        return self._sum
def aggregate_unnamed_reports(reports: List[Dict[str, Metric]]) -> Dict[str, Metric]:
    """
    Merge several metric reports into one, summing metrics that share a key.

    Relies on ``Metric.__radd__`` treating ``None`` as the additive identity,
    so a key missing from the accumulator simply adopts the incoming value.
    """
    merged: Dict[str, Metric] = {}
    for report in reports:
        for name, metric in report.items():
            merged[name] = merged.get(name) + metric
    return merged
+ """ + return {k: (v if "cov" not in k else SumMetric(len(set(v.value())))) for k, v in self._data.items()} + + def clear(self): + """ + Clear all the metrics. + """ + self._data.clear() diff --git a/HiCore/crslab/evaluator/metrics/gen.py b/HiCore/crslab/evaluator/metrics/gen.py new file mode 100644 index 0000000..89b7a1a --- /dev/null +++ b/HiCore/crslab/evaluator/metrics/gen.py @@ -0,0 +1,158 @@ +# @Time : 2020/11/30 +# @Author : Xiaolei Wang +# @Email : wxl1999@foxmail.com + +# UPDATE: +# @Time : 2020/12/18 +# @Author : Xiaolei Wang +# @Email : wxl1999@foxmail.com + +import re +from collections import Counter + +import math +import numpy as np +from nltk import ngrams +from nltk.translate.bleu_score import sentence_bleu +from sklearn.metrics.pairwise import cosine_similarity +from typing import List, Optional + +from crslab.evaluator.metrics.base import AverageMetric, SumMetric + +re_art = re.compile(r'\b(a|an|the)\b') +re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']') +re_space = re.compile(r'\s+') + + +class PPLMetric(AverageMetric): + def value(self): + return math.exp(super().value()) + + +def normalize_answer(s): + """ + Lower text and remove punctuation, articles and extra whitespace. + """ + + s = s.lower() + s = re_punc.sub(' ', s) + s = re_art.sub(' ', s) + s = re_space.sub(' ', s) + # s = ' '.join(s.split()) + return s + + +class ExactMatchMetric(AverageMetric): + @staticmethod + def compute(guess: str, answers: List[str]) -> 'ExactMatchMetric': + if guess is None or answers is None: + return None + for a in answers: + if guess == a: + return ExactMatchMetric(1) + return ExactMatchMetric(0) + + +class F1Metric(AverageMetric): + """ + Helper class which computes token-level F1. + """ + + @staticmethod + def _prec_recall_f1_score(pred_items, gold_items): + """ + Compute precision, recall and f1 given a set of gold and prediction items. 
class EmbeddingAverage(AverageMetric):
    """Embedding-average similarity of a hypothesis against reference sentences."""

    @staticmethod
    def _avg_embedding(embedding):
        # mean direction of the token vectors, L2-normalised
        # (epsilon guards against division by zero for all-zero inputs)
        summed = np.sum(embedding, axis=0)
        return summed / (np.linalg.norm(summed) + 1e-12)

    @staticmethod
    def compute(hyp_embedding, ref_embeddings) -> 'EmbeddingAverage':
        hyp_vec = EmbeddingAverage._avg_embedding(hyp_embedding).reshape(1, -1)
        ref_vecs = np.array([EmbeddingAverage._avg_embedding(ref) for ref in ref_embeddings])
        # score against the best-matching reference
        best = cosine_similarity(hyp_vec, ref_vecs).max()
        return EmbeddingAverage(float(best))
class RECMetric(AverageMetric):
    """Recall@k: 1 if the gold item appears in the top-k ranks, else 0."""

    @staticmethod
    def compute(ranks, label, k) -> 'RECMetric':
        hit = 1 if label in ranks[:k] else 0
        return RECMetric(hit)


class NDCGMetric(AverageMetric):
    """NDCG@k with a single relevant item: 1/log2(rank + 2) when in the top k."""

    @staticmethod
    def compute(ranks, label, k) -> 'NDCGMetric':
        if label not in ranks[:k]:
            return NDCGMetric(0)
        position = ranks.index(label)
        return NDCGMetric(1 / math.log2(position + 2))
    def __init__(self, language, file_path=None):
        # `language` is accepted but not stored here — presumably consumed by
        # embedding-based evaluators sharing this registry signature; TODO confirm.
        # `file_path`, when set, is where report() dumps accumulated results as JSON.
        super(StandardEvaluator, self).__init__()
        self.file_path = file_path
        self.result_data = []  # one {'epoch', 'mode', 'report'} entry per report() call
        # rec: recall/ndcg/mrr@k accumulators
        self.rec_metrics = Metrics()
        # gen: dist-n bookkeeping (unique n-grams per k, total sentence count)
        self.dist_set = defaultdict(set)
        self.dist_cnt = 0
        self.gen_metrics = Metrics()
        # optim: training-time metrics (e.g. loss) used for early stopping
        self.optim_metrics = Metrics()
self.rec_metrics.add(f"mrr@{k}", MRRMetric.compute(ranks, label, k)) + + def gen_evaluate(self, hyp, refs, seq=None): + if hyp: + self.gen_metrics.add("f1", F1Metric.compute(hyp, refs)) + + for k in range(1, 5): + self.gen_metrics.add(f"bleu@{k}", BleuMetric.compute(hyp, refs, k)) + for token in ngrams(seq, k): + self.dist_set[f"dist@{k}"].add(token) + self.dist_cnt += 1 + + def report(self, epoch=-1, mode='test'): + for k, v in self.dist_set.items(): + self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt)) + reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()] + all_reports = aggregate_unnamed_reports(reports) + self.result_data.append({ + 'epoch': epoch, + 'mode': mode, + 'report': {k:all_reports[k].value() for k in all_reports} + }) + if self.file_path: + json.dump(self.result_data, open(self.file_path, "w", encoding="utf-8"), indent=4, ensure_ascii=False) + logger.info('\n' + nice_report(all_reports)) + + def reset_metrics(self): + # rec + self.rec_metrics.clear() + # conv + self.gen_metrics.clear() + self.dist_cnt = 0 + self.dist_set.clear() + # optim + self.optim_metrics.clear() diff --git a/HiCore/crslab/evaluator/utils.py b/HiCore/crslab/evaluator/utils.py new file mode 100644 index 0000000..c29a20a --- /dev/null +++ b/HiCore/crslab/evaluator/utils.py @@ -0,0 +1,160 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/17 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/12/17 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +import json +import re +import shutil +from collections import OrderedDict + +import math +import torch +from typing import Union, Tuple + +from .metrics import Metric + + +def _line_width(): + try: + # if we're in an interactive ipython notebook, hardcode a longer width + __IPYTHON__ # type: ignore + return 128 + except NameError: + return shutil.get_terminal_size((88, 24)).columns + + +def float_formatter(f: Union[float, int]) -> str: + """ + 
def round_sigfigs(x: Union[float, 'torch.Tensor'], sigfigs=4) -> float:
    """
    Round a value to the given number of significant figures.

    :param x: input number (python scalar or 0-d torch tensor)
    :param sigfigs: number of significant figures to keep

    :returns: the rounded float; 0 stays 0, and inf/-inf/nan pass through unchanged
    """
    value = x.item() if isinstance(x, torch.Tensor) else x

    if value == 0:
        return 0
    try:
        # number of decimal places that preserves `sigfigs` significant digits
        digits = -math.floor(math.log10(abs(value)) - sigfigs + 1)
        return round(value, digits)
    except (ValueError, OverflowError) as ex:
        if value in [float('inf'), float('-inf')] or value != value:  # inf or nan
            return value
        raise ex
# Registry mapping model names (as they appear in config files) to classes.
Model_register_table = {
    'HiCore': HiCoreModel,
}


def get_model(config, model_name, device, vocab, side_data=None):
    """
    Look up and instantiate a CRS model by name.

    :param config: global config object; ``config["gpu"]`` lists device ids
        (``[-1]`` means CPU-only)
    :param model_name: key into ``Model_register_table`` (e.g. ``'HiCore'``)
    :param device: torch device the model should place its tensors on
    :param vocab: vocabulary dict forwarded to the model constructor
    :param side_data: optional side information forwarded to the model
    :returns: the built model, wrapped in ``DataParallel`` when GPUs are used
    :raises NotImplementedError: if ``model_name`` is not registered
    """
    if model_name not in Model_register_table:
        raise NotImplementedError('Model [{}] has not been implemented'.format(model_name))
    model = Model_register_table[model_name](config, device, vocab, side_data)
    logger.info(f'[Build model {model_name}]')
    # FIX: the gpu list was previously read as ``config.opt["gpu"]`` here but
    # as ``config["gpu"]`` two lines below; use the subscript form
    # consistently (the second access shows Config supports __getitem__).
    if config["gpu"] == [-1]:
        # CPU-only run: no DataParallel wrapper.
        return model
    return torch.nn.DataParallel(model, device_ids=config["gpu"])
+ """ + pass diff --git a/HiCore/crslab/model/crs/__init__.py b/HiCore/crslab/model/crs/__init__.py new file mode 100644 index 0000000..dae0eeb --- /dev/null +++ b/HiCore/crslab/model/crs/__init__.py @@ -0,0 +1 @@ +from .hicore import * diff --git a/HiCore/crslab/model/crs/hicore/__init__.py b/HiCore/crslab/model/crs/hicore/__init__.py new file mode 100644 index 0000000..9c2981e --- /dev/null +++ b/HiCore/crslab/model/crs/hicore/__init__.py @@ -0,0 +1 @@ +from .hicore import HiCoreModel \ No newline at end of file diff --git a/HiCore/crslab/model/crs/hicore/attention.py b/HiCore/crslab/model/crs/hicore/attention.py new file mode 100644 index 0000000..a3aae14 --- /dev/null +++ b/HiCore/crslab/model/crs/hicore/attention.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Time : 2021/4/1 +# @Author : Chenzhan Shang +# @Email : czshang@outlook.com + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Multi-head +class MHItemAttention(nn.Module): + def __init__(self, dim, head_num): + super(MHItemAttention, self).__init__() + self.MHA = torch.nn.MultiheadAttention(dim, head_num, batch_first=True) + + def forward(self, related_entity, context_entity): + """ + input: + related_entity: (n_r, dim) + context_entity: (n_c, dim) + output: + related_context_entity: (n_c, dim) + """ + context_entity = torch.unsqueeze(context_entity, 0) + related_entity = torch.unsqueeze(related_entity, 0) + output, _ = self.MHA(context_entity, related_entity, related_entity) + return torch.squeeze(output, 0) diff --git a/HiCore/crslab/model/crs/hicore/decoder.py b/HiCore/crslab/model/crs/hicore/decoder.py new file mode 100644 index 0000000..d87d12b --- /dev/null +++ b/HiCore/crslab/model/crs/hicore/decoder.py @@ -0,0 +1,200 @@ +import torch +import numpy as np +from torch import nn +from crslab.model.utils.modules.transformer import MultiHeadAttention, TransformerFFN, _normalize, create_position_codes + + +class TransformerDecoderLayerKG(nn.Module): + def __init__( + self, 
+ n_heads, + embedding_size, + ffn_size, + attention_dropout=0.0, + relu_dropout=0.0, + dropout=0.0, + ): + super().__init__() + self.dim = embedding_size + self.ffn_dim = ffn_size + self.dropout = nn.Dropout(p=dropout) + + self.self_attention = MultiHeadAttention( + n_heads, embedding_size, dropout=attention_dropout + ) + self.norm_self_attention = nn.LayerNorm(embedding_size) + + self.session_item_attention = MultiHeadAttention( + n_heads, embedding_size, dropout=attention_dropout + ) + self.norm_session_item_attention = nn.LayerNorm(embedding_size) + + self.related_encoder_attention = MultiHeadAttention( + n_heads, embedding_size, dropout=attention_dropout + ) + self.norm_related_encoder_attention = nn.LayerNorm(embedding_size) + + self.context_encoder_attention = MultiHeadAttention( + n_heads, embedding_size, dropout=attention_dropout + ) + self.norm_context_encoder_attention = nn.LayerNorm(embedding_size) + + self.norm_merge = nn.LayerNorm(embedding_size) + + self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout) + self.norm3 = nn.LayerNorm(embedding_size) + + def forward(self, x, related_encoder_output, related_encoder_mask, context_encoder_output, context_encoder_mask, + session_embedding, session_mask): + decoder_mask = self._create_selfattn_mask(x) + # first self attn + residual = x + # don't peak into the future! 
+ x = self.self_attention(query=x, mask=decoder_mask) + x = self.dropout(x) # --dropout + x = x + residual + x = _normalize(x, self.norm_self_attention) + + residual = x + x = self.session_item_attention( + query=x, + key=session_embedding, + value=session_embedding, + mask=session_mask + ) + x = self.dropout(x) + x = residual + x + x = _normalize(x, self.norm_session_item_attention) + + residual = x + related_x = self.related_encoder_attention( + query=x, + key=related_encoder_output, + value=related_encoder_output, + mask=related_encoder_mask + ) + related_x = self.dropout(related_x) # --dropout + + context_x = self.context_encoder_attention( + query=x, + key=context_encoder_output, + value=context_encoder_output, + mask=context_encoder_mask + ) + context_x = self.dropout(context_x) # --dropout + + x = related_x * 0.1 + context_x * 0.9 + residual + x = _normalize(x, self.norm_merge) + + # finally the ffn + residual = x + x = self.ffn(x) + x = self.dropout(x) # --dropout + x = residual + x + x = _normalize(x, self.norm3) + + return x + + def _create_selfattn_mask(self, x): + # figure out how many timestamps we need + bsz = x.size(0) + time = x.size(1) + # make sure that we don't look into the future + mask = torch.tril(x.new(time, time).fill_(1)) + # broadcast across batch + mask = mask.unsqueeze(0).expand(bsz, -1, -1) + return mask + + +class TransformerDecoderKG(nn.Module): + """ + Transformer Decoder layer. + + :param int n_heads: the number of multihead attention heads. + :param int n_layers: number of transformer layers. + :param int embedding_size: the embedding sizes. Must be a multiple of n_heads. + :param int ffn_size: the size of the hidden layer in the FFN + :param embedding: an embedding matrix for the bottom layer of the transformer. + If none, one is created for this encoder. + :param float dropout: Dropout used around embeddings and before layer + layer normalizations. This is used in Vaswani 2017 and works well on + large datasets. 
class TransformerDecoderKG(nn.Module):
    """
    Transformer Decoder stack with knowledge-grounded cross-attention.

    :param int n_heads: the number of multihead attention heads.
    :param int n_layers: number of transformer layers.
    :param int embedding_size: the embedding sizes. Must be a multiple of n_heads.
    :param int ffn_size: the size of the hidden layer in the FFN
    :param embedding: an embedding matrix for the bottom layer of the transformer.
        If none, one is created for this encoder.
    :param float dropout: Dropout used around embeddings and before layer
        normalizations. This is used in Vaswani 2017 and works well on
        large datasets.
    :param float attention_dropout: Dropout performed after the multhead attention
        softmax. This is not used in Vaswani 2017.
    :param int padding_idx: Reserved padding index in the embeddings matrix.
    :param bool learn_positional_embeddings: If off, sinusoidal embeddings are
        used. If on, position embeddings are learned from scratch.
    :param bool embeddings_scale: Scale embeddings relative to their dimensionality.
        Found useful in fairseq.
    :param int n_positions: Size of the position embeddings matrix.
    """

    def __init__(
        self,
        n_heads,
        n_layers,
        embedding_size,
        ffn_size,
        vocabulary_size,
        embedding=None,
        dropout=0.0,
        attention_dropout=0.0,
        relu_dropout=0.0,
        embeddings_scale=True,
        learn_positional_embeddings=False,
        padding_idx=None,
        n_positions=1024,
    ):
        super().__init__()
        self.embedding_size = embedding_size
        self.ffn_size = ffn_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dim = embedding_size
        self.embeddings_scale = embeddings_scale
        self.dropout = nn.Dropout(p=dropout)  # --dropout
        self.out_dim = embedding_size

        assert embedding_size % n_heads == 0, \
            'Transformer embedding size must be a multiple of n_heads'

        # bottom-layer token embeddings are supplied by the caller
        self.embeddings = embedding

        # positional embeddings: sinusoidal by default, learned on request
        self.position_embeddings = nn.Embedding(n_positions, embedding_size)
        if learn_positional_embeddings:
            nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)
        else:
            create_position_codes(
                n_positions, embedding_size, out=self.position_embeddings.weight
            )

        # the decoder stack itself
        self.layers = nn.ModuleList([
            TransformerDecoderLayerKG(
                n_heads, embedding_size, ffn_size,
                attention_dropout=attention_dropout,
                relu_dropout=relu_dropout,
                dropout=dropout,
            )
            for _ in range(self.n_layers)
        ])

    def forward(self, input, related_encoder_state, context_encoder_state, session_state, incr_state=None):
        """
        Decode ``input`` token ids against the two encoder memories and the
        session embedding. ``incr_state`` is accepted for interface
        compatibility but unused; returns ``(hidden_states, None)``.
        """
        related_encoder_output, related_encoder_mask = related_encoder_state
        context_encoder_output, context_encoder_mask = context_encoder_state
        session_embedding, session_mask = session_state

        # positions 0..seq_len-1 on the same device/dtype as the input ids
        seq_len = input.shape[1]
        positions = input.new_empty(seq_len).long()
        positions = torch.arange(seq_len, out=positions).unsqueeze(0)

        hidden = self.embeddings(input)
        if self.embeddings_scale:
            hidden = hidden * np.sqrt(self.dim)
        hidden = hidden + self.position_embeddings(positions).expand_as(hidden)
        hidden = self.dropout(hidden)  # --dropout

        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, related_encoder_output, related_encoder_mask,
                                   context_encoder_output, context_encoder_mask,
                                   session_embedding, session_mask)

        return hidden, None
_`"Towards Knowledge-Based Recommender Dialog System."`: + https://www.aclweb.org/anthology/D19-1189/ + +""" + +import json +import os.path +import random +import pickle +from typing import List +from time import perf_counter +from scipy import sparse + +import torch +import torch.nn.functional as F +from loguru import logger +from torch import nn +from tqdm import tqdm +from torch_geometric.nn import RGCNConv, HypergraphConv + +from crslab.config import DATA_PATH, DATASET_PATH +from crslab.model.base import BaseModel +from crslab.model.crs.hicore.attention import MHItemAttention +from crslab.model.utils.functions import edge_to_pyg_format +from crslab.model.utils.modules.attention import SelfAttentionBatch, SelfAttentionSeq +from crslab.model.utils.modules.transformer import TransformerEncoder +from crslab.model.crs.hicore.decoder import TransformerDecoderKG + +class GatingLayer(nn.Module): + def __init__(self, dim): + super(GatingLayer, self).__init__() + self.dim = dim + self.linear = nn.Linear(self.dim, self.dim) + self.activation = nn.Sigmoid() + + def forward(self, emb): + embedding = self.linear(emb) + embedding = self.activation(embedding) + embedding = torch.mul(emb, embedding) + return embedding + +class AttLayer(nn.Module): + def __init__(self, dim): + super(AttLayer, self).__init__() + self.dim = dim + self.attention_mat = nn.Parameter(torch.randn([self.dim, self.dim])) + self.attention = nn.Parameter(torch.randn([1, self.dim])) + + def forward(self, *embs): + embeddings = torch.cat(embs, dim=0) + + transformed = torch.matmul(embeddings, self.attention_mat) + weights = torch.sum(torch.mul(self.attention, transformed), dim=1) + score = torch.softmax(weights, dim=0) + + mixed_embeddings = torch.sum(embeddings * score.unsqueeze(1), dim=0) + return mixed_embeddings.unsqueeze(0) + +class HiCoreModel(BaseModel): + """ + + Attributes: + vocab_size: A integer indicating the vocabulary size. + pad_token_idx: A integer indicating the id of padding token. 
+ start_token_idx: A integer indicating the id of start token. + end_token_idx: A integer indicating the id of end token. + token_emb_dim: A integer indicating the dimension of token embedding layer. + pretrain_embedding: A string indicating the path of pretrained embedding. + n_entity: A integer indicating the number of entities. + n_relation: A integer indicating the number of relation in KG. + num_bases: A integer indicating the number of bases. + kg_emb_dim: A integer indicating the dimension of kg embedding. + user_emb_dim: A integer indicating the dimension of user embedding. + n_heads: A integer indicating the number of heads. + n_layers: A integer indicating the number of layer. + ffn_size: A integer indicating the size of ffn hidden. + dropout: A float indicating the dropout rate. + attention_dropout: A integer indicating the dropout rate of attention layer. + relu_dropout: A integer indicating the dropout rate of relu layer. + learn_positional_embeddings: A boolean indicating if we learn the positional embedding. + embeddings_scale: A boolean indicating if we use the embeddings scale. + reduction: A boolean indicating if we use the reduction. + n_positions: A integer indicating the number of position. + longest_label: A integer indicating the longest length for response generation. + user_proj_dim: A integer indicating dim to project for user embedding. + + """ + + def __init__(self, opt, device, vocab, side_data): + """ + + Args: + opt (dict): A dictionary record the hyper parameters. + device (torch.device): A variable indicating which device to place the data and model. + vocab (dict): A dictionary record the vocabulary information. + side_data (dict): A dictionary record the side data. 
+ + """ + self.device = device + self.gpu = opt.get("gpu", -1) + self.dataset = opt.get("dataset", None) + assert self.dataset in ['HReDial', 'HTGReDial', 'DuRecDial', 'OpenDialKG', 'ReDial', 'TGReDial'] + # vocab + self.pad_token_idx = vocab['tok2ind']['__pad__'] + self.start_token_idx = vocab['tok2ind']['__start__'] + self.end_token_idx = vocab['tok2ind']['__end__'] + self.vocab_size = vocab['vocab_size'] + self.token_emb_dim = opt.get('token_emb_dim', 300) + self.pretrain_embedding = side_data.get('embedding', None) + self.token2id = json.load(open(os.path.join(DATASET_PATH, self.dataset.lower(), opt["tokenize"], "token2id.json"), "r", encoding="utf-8")) + self.entity2id = json.load(open(os.path.join(DATASET_PATH, self.dataset.lower(), opt["tokenize"], "entity2id.json"), "r", encoding="utf-8")) + # kg + self.n_entity = vocab['n_entity'] + self.entity_kg = side_data['entity_kg'] + self.n_relation = self.entity_kg['n_relation'] + self.edge_idx, self.edge_type = edge_to_pyg_format(self.entity_kg['edge'], 'RGCN') + self.edge_idx = self.edge_idx.to(device) + self.edge_type = self.edge_type.to(device) + self.num_bases = opt.get('num_bases', 8) + self.kg_emb_dim = opt.get('kg_emb_dim', 300) + self.user_emb_dim = self.kg_emb_dim + # transformer + self.n_heads = opt.get('n_heads', 2) + self.n_layers = opt.get('n_layers', 2) + self.ffn_size = opt.get('ffn_size', 300) + self.dropout = opt.get('dropout', 0.1) + self.attention_dropout = opt.get('attention_dropout', 0.0) + self.relu_dropout = opt.get('relu_dropout', 0.1) + self.embeddings_scale = opt.get('embedding_scale', True) + self.learn_positional_embeddings = opt.get('learn_positional_embeddings', False) + self.reduction = opt.get('reduction', False) + self.n_positions = opt.get('n_positions', 1024) + self.longest_label = opt.get('longest_label', 30) + self.user_proj_dim = opt.get('user_proj_dim', 512) + # pooling + self.pooling = opt.get('pooling', None) + assert self.pooling == 'Attn' or self.pooling == 'Mean' + # MHA 
+ self.mha_n_heads = opt.get('mha_n_heads', 4) + self.extension_strategy = opt.get('extension_strategy', None) + self.pretrain = opt.get('pretrain', False) + self.pretrain_data = None + self.pretrain_epoch = opt.get('pretrain_epoch', 9999) + + super(HiCoreModel, self).__init__(opt, device) + return + + # 构建模型 + def build_model(self, *args, **kwargs): + if self.pretrain: + pretrain_file = os.path.join('pretrain', self.dataset, str(self.pretrain_epoch) + '-epoch.pth') + self.pretrain_data = torch.load(pretrain_file, map_location=torch.device('cuda:' + str(self.gpu[0]))) + logger.info(f"[Load Pretrain Weights from {pretrain_file}]") + # self._build_hredial_copy_mask() + self._build_adjacent_matrix() + # self._build_hllm_data() + self._build_embedding() + self._build_kg_layer() + self._build_recommendation_layer() + self._build_conversation_layer() + self._build_gating_layer() + + def _build_gating_layer(self): + self.gate_items_H_j = GatingLayer(self.kg_emb_dim) + self.gate_items_H_s = GatingLayer(self.kg_emb_dim) + self.gate_items_H_p = GatingLayer(self.kg_emb_dim) + self.gate_items_H_o = GatingLayer(self.kg_emb_dim) + self.gate_entitys_H_j = GatingLayer(self.kg_emb_dim) + self.gate_entitys_H_s = GatingLayer(self.kg_emb_dim) + self.gate_entitys_H_p = GatingLayer(self.kg_emb_dim) + self.gate_entitys_H_o = GatingLayer(self.kg_emb_dim) + self.gate_words_H_j = GatingLayer(self.kg_emb_dim) + self.gate_words_H_s = GatingLayer(self.kg_emb_dim) + self.gate_words_H_p = GatingLayer(self.kg_emb_dim) + self.gate_words_H_o = GatingLayer(self.kg_emb_dim) + self.attn_items = AttLayer(self.kg_emb_dim) + self.attn_entitys = AttLayer(self.kg_emb_dim) + self.attn_words = AttLayer(self.kg_emb_dim) + return + # 构建 mask + def _build_hredial_copy_mask(self): + token_filename = os.path.join(DATASET_PATH, "hredial", "nltk", "token2id.json") + token_file = open(token_filename, 'r', encoding="utf-8") + token2id = json.load(token_file) + id2token = {token2id[token]: token for token in token2id} 
+ self.hredial_copy_mask = list() + for i in range(len(id2token)): + token = id2token[i] + if token[0] == '@': + self.hredial_copy_mask.append(True) + else: + self.hredial_copy_mask.append(False) + self.hredial_copy_mask = torch.as_tensor(self.hredial_copy_mask).to(self.device) + return + + def _build_hllm_data(self): + self.hllm_data_table = { + "train": pickle.load(open(os.path.join(DATA_PATH, "hllm", self.dataset.lower(), self.llm, "hllm_train_data.pkl"), "rb")), + "valid": pickle.load(open(os.path.join(DATA_PATH, "hllm", self.dataset.lower(), self.llm, "hllm_valid_data.pkl"), "rb")), + "test": pickle.load(open(os.path.join(DATA_PATH, "hllm", self.dataset.lower(), self.llm, "hllm_test_data.pkl"), "rb")), + } + return + + def _build_adjacent_matrix(self): + graph = dict() + for head, tail, relation in tqdm(self.entity_kg['edge']): + graph[head] = graph.get(head, []) + [tail] + adj = dict() + for entity in tqdm(range(self.n_entity)): + adj[entity] = set() + if entity not in graph: + continue + last_hop = {entity} + for _ in range(1): + buffer = set() + for source in last_hop: + adj[entity].update(graph[source]) + buffer.update(graph[source]) + last_hop = buffer + self.adj = adj + self.u2e = sparse.csr_matrix(sparse.load_npz(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "mat", "u2e.npz"))) + self.u2i = sparse.csr_matrix(sparse.load_npz(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "mat", "u2i.npz"))) + self.u2w = sparse.csr_matrix(sparse.load_npz(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "mat", "u2w.npz"))) + self.e2e = sparse.csr_matrix(sparse.load_npz(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "mat", "e2e.npz"))) + self.i2i = sparse.csr_matrix(sparse.load_npz(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "mat", "i2i.npz"))) + self.w2w = sparse.csr_matrix(sparse.load_npz(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "mat", "w2w.npz"))) + self.items_H_s, self.items_H_j, self.items_H_p = 
self._build_motif_adj_matrix(self.i2i, self.u2i.T) + self.entitys_H_s, self.entitys_H_j, self.entitys_H_p = self._build_motif_adj_matrix(self.e2e, self.u2e.T) + self.words_H_s, self.words_H_j, self.words_H_p = self._build_motif_adj_matrix(self.w2w, self.u2w.T) + logger.info(f"[Adjacent Matrix built.]") + return + + def _build_motif_adj_matrix(self, net_matrix:sparse.csr_matrix, inter_matrix:sparse.csr_matrix) -> tuple[sparse.csr_matrix, sparse.csr_matrix, sparse.csr_matrix]: + S = net_matrix + Y = inter_matrix + B = S.multiply(S.T) + U = S - B + C1 = (U.dot(U)).multiply(U.T) + A1 = C1 + C1.T + C2 = (B.dot(U)).multiply(U.T) + (U.dot(B)).multiply(U.T) + (U.dot(U)).multiply(B) + A2 = C2 + C2.T + C3 = (B.dot(B)).multiply(U) + (B.dot(U)).multiply(B) + (U.dot(B)).multiply(B) + A3 = C3 + C3.T + A4 = (B.dot(B)).multiply(B) + C5 = (U.dot(U)).multiply(U) + (U.dot(U.T)).multiply(U) + (U.T.dot(U)).multiply(U) + A5 = C5 + C5.T + A6 = (U.dot(B)).multiply(U) + (B.dot(U.T)).multiply(U.T) + (U.T.dot(U)).multiply(B) + A7 = (U.T.dot(B)).multiply(U.T) + (B.dot(U)).multiply(U) + (U.dot(U.T)).multiply(B) + A8 = (Y.dot(Y.T)).multiply(B) + A9 = (Y.dot(Y.T)).multiply(U) + A9 = A9 + A9.T + A10 = Y.dot(Y.T) - A8 - A9 + H_s = sum([A1, A2, A3, A4, A5, A6, A7]) + H_s = H_s.multiply(1.0 / (H_s.sum(axis=1) + 1e-7).reshape(-1, 1)) + H_j = sum([A8, A9]) + H_j = H_j.multiply(1.0 / (H_j.sum(axis=1) + 1e-7).reshape(-1, 1)) + H_p = A10 + H_p = H_p.multiply(H_p > 1) + H_p = H_p.multiply(1.0 / (H_p.sum(axis=1) + 1e-7).reshape(-1, 1)) + H_s = self.mat2adj(sparse.csr_matrix(H_s)) + H_j = self.mat2adj(sparse.csr_matrix(H_j)) + H_p = self.mat2adj(sparse.csr_matrix(H_p)) + return H_s, H_j, H_p + + def mat2adj(self, mat:sparse.csr_matrix): + adj = dict() + mat = sparse.coo_matrix(mat) + for node1, node2, weight in zip(mat.row, mat.col, mat.data): + node1 = int(node1) + node2 = int(node2) + weight = float(weight) + if weight >= 0.5: + if node1 not in adj: + adj[node1] = set() + adj[node1].add(node2) + return 
adj + + # 构建编码层 + def _build_embedding(self): + if self.pretrain_embedding is not None: + self.token_embedding = nn.Embedding.from_pretrained( + torch.as_tensor(self.pretrain_embedding, dtype=torch.float), freeze=False, + padding_idx=self.pad_token_idx) + else: + self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx) + nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) + nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0) + + self.entity_embedding = nn.Embedding(self.n_entity, self.kg_emb_dim, 0) + nn.init.normal_(self.entity_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) + nn.init.constant_(self.entity_embedding.weight[0], 0) + self.word_embedding = nn.Embedding(self.n_entity, self.kg_emb_dim, 0) + nn.init.normal_(self.word_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5) + nn.init.constant_(self.word_embedding.weight[0], 0) + logger.debug('[Build embedding]') + return + + # 构建超图编码层 + def _build_kg_layer(self): + # graph encoder + self.item_encoder = RGCNConv(self.kg_emb_dim, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases) + self.entity_encoder = RGCNConv(self.kg_emb_dim, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases) + self.word_encoder = RGCNConv(self.kg_emb_dim, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases) + if self.pretrain: + self.item_encoder.load_state_dict(self.pretrain_data['encoder']) + # hypergraph convolution + self.hyper_conv_item = HypergraphConv(self.kg_emb_dim, self.kg_emb_dim) + self.hyper_conv_entity = HypergraphConv(self.kg_emb_dim, self.kg_emb_dim) + self.hyper_conv_word = HypergraphConv(self.kg_emb_dim, self.kg_emb_dim) + # attention type + self.item_attn = MHItemAttention(self.kg_emb_dim, self.mha_n_heads) + # pooling + if self.pooling == 'Attn': + self.kg_attn = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim) + self.kg_attn_his = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim) + 
logger.debug('[Build kg layer]') + return + + # 构建推荐模块 + def _build_recommendation_layer(self): + self.rec_bias = nn.Linear(self.kg_emb_dim, self.n_entity) + self.rec_loss = nn.CrossEntropyLoss() + logger.debug('[Build recommendation layer]') + return + + # 构建对话模块 + def _build_conversation_layer(self): + self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long)) + self.entity_to_token = nn.Linear(self.kg_emb_dim, self.token_emb_dim, bias=True) + self.related_encoder = TransformerEncoder( + self.n_heads, + self.n_layers, + self.token_emb_dim, + self.ffn_size, + self.vocab_size, + self.token_embedding, + self.dropout, + self.attention_dropout, + self.relu_dropout, + self.pad_token_idx, + self.learn_positional_embeddings, + self.embeddings_scale, + self.reduction, + self.n_positions + ) + self.context_encoder = TransformerEncoder( + self.n_heads, + self.n_layers, + self.token_emb_dim, + self.ffn_size, + self.vocab_size, + self.token_embedding, + self.dropout, + self.attention_dropout, + self.relu_dropout, + self.pad_token_idx, + self.learn_positional_embeddings, + self.embeddings_scale, + self.reduction, + self.n_positions + ) + self.decoder = TransformerDecoderKG( + self.n_heads, + self.n_layers, + self.token_emb_dim, + self.ffn_size, + self.vocab_size, + self.token_embedding, + self.dropout, + self.attention_dropout, + self.relu_dropout, + self.embeddings_scale, + self.learn_positional_embeddings, + self.pad_token_idx, + self.n_positions + ) + self.user_proj_1 = nn.Linear(self.user_emb_dim, self.user_proj_dim) + self.user_proj_2 = nn.Linear(self.user_proj_dim, self.vocab_size) + self.conv_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx) + + self.copy_proj_1 = nn.Linear(2 * self.token_emb_dim, self.token_emb_dim) + self.copy_proj_2 = nn.Linear(self.token_emb_dim, self.vocab_size) + logger.debug('[Build conversation layer]') + return + + # 获取超图 + def _get_hypergraph(self, related, adj): + related_items_set = set() + for 
related_items in related: + related_items_set.add(related_items) + session_related_items = list(related_items_set) + + hypergraph_nodes, hypergraph_edges, hyper_edge_counter = list(), list(), 0 + for item in session_related_items: + hypergraph_nodes.append(item) + hypergraph_edges.append(hyper_edge_counter) + neighbors = list(adj.get(item, [])) + hypergraph_nodes += neighbors + hypergraph_edges += [hyper_edge_counter] * len(neighbors) + hyper_edge_counter += 1 + hyper_edge_index = torch.tensor([hypergraph_nodes, hypergraph_edges], device=self.device) + return list(set(hypergraph_nodes)), hyper_edge_index + + # 获取聚合 + def _get_embedding(self, hypergraph_items, embedding, tot_sub, adj): + knowledge_embedding_list = [] + for item in hypergraph_items: + sub_graph = [item] + list(adj.get(item, [])) + sub_graph = [tot_sub[item] for item in sub_graph] + sub_graph_embedding = embedding[sub_graph] + sub_graph_embedding = torch.mean(sub_graph_embedding, dim=0) + knowledge_embedding_list.append(sub_graph_embedding) + res_embedding = torch.zeros(1, self.kg_emb_dim).to(self.device) + if len(knowledge_embedding_list) > 0: + res_embedding = torch.stack(knowledge_embedding_list, dim=0) + return res_embedding + + @staticmethod + def flatten(inputs): + outputs = set() + for li in inputs: + for i in li: + outputs.add(i) + return list(outputs) + + # 注意力融合特征向量 + def _attention_and_gating(self, session_embedding, knowledge_embedding, conceptnet_embedding, context_embedding): + related_embedding = torch.cat((session_embedding, knowledge_embedding, conceptnet_embedding), dim=0) + if context_embedding is None: + if self.pooling == 'Attn': + user_repr = self.kg_attn_his(related_embedding) + else: + assert self.pooling == 'Mean' + user_repr = torch.mean(related_embedding, dim=0) + elif self.pooling == 'Attn': + attentive_related_embedding = self.item_attn(related_embedding, context_embedding) + user_repr = self.kg_attn_his(attentive_related_embedding) + user_repr = torch.unsqueeze(user_repr, 
dim=0) + user_repr = torch.cat((context_embedding, user_repr), dim=0) + user_repr = self.kg_attn(user_repr) + else: + assert self.pooling == 'Mean' + attentive_related_embedding = self.item_attn(related_embedding, context_embedding) + user_repr = torch.mean(attentive_related_embedding, dim=0) + user_repr = torch.unsqueeze(user_repr, dim=0) + user_repr = torch.cat((context_embedding, user_repr), dim=0) + user_repr = torch.mean(user_repr, dim=0) + return user_repr + + def _get_hllm_embedding(self, tot_embedding, hllm_hyper_graph, adj, conv): + hllm_hyper_edge_A = [] + hllm_hyper_edge_B = [] + for idx, hyper_edge in enumerate(hllm_hyper_graph): + hllm_hyper_edge_A += [item for item in hyper_edge] + hllm_hyper_edge_B += [idx] * len(hyper_edge) + + hllm_items = list(set(hllm_hyper_edge_A)) + sub_item2id = {item:idx for idx, item in enumerate(hllm_items)} + sub_embedding = tot_embedding[hllm_items] + + hllm_hyper_edge = [[sub_item2id[item] for item in hllm_hyper_edge_A], hllm_hyper_edge_B] + hllm_hyper_edge = torch.LongTensor(hllm_hyper_edge).to(self.device) + + embedding = conv(sub_embedding, hllm_hyper_edge) + + return embedding + + def encode_gating(self, lists, H_j_adj, H_p_adj, H_s_adj, H_j_gate, H_p_gate, H_s_gate, H_o_gate, Attn, hyper_conv, kg_embedding): + items, item_hyper_edge_index = self._get_hypergraph(lists, H_j_adj) + sub_item_embedding, sub_item_edge_index, item_tot2sub = self._before_hyperconv(kg_embedding, items, item_hyper_edge_index, H_j_adj) + H_j_embedding = self.hyper_conv_item(sub_item_embedding, sub_item_edge_index) + + items, item_hyper_edge_index = self._get_hypergraph(lists, H_p_adj) + sub_item_embedding, sub_item_edge_index, item_tot2sub = self._before_hyperconv(kg_embedding, items, item_hyper_edge_index, H_p_adj) + H_p_embedding = self.hyper_conv_item(sub_item_embedding, sub_item_edge_index) + + items, item_hyper_edge_index = self._get_hypergraph(lists, H_s_adj) + sub_item_embedding, sub_item_edge_index, item_tot2sub = 
self._before_hyperconv(kg_embedding, items, item_hyper_edge_index, H_s_adj) + H_s_embedding = self.hyper_conv_item(sub_item_embedding, sub_item_edge_index) + + raw_embedding = kg_embedding[lists] + + all_emb = Attn(H_j_embedding, H_p_embedding, H_s_embedding, raw_embedding) + + return all_emb + + def encode_user_repr(self, related_items, related_entities, related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding): + # COLD START + # if len(related_items) or len(related_words) == 0: + # if len(related_entities) == 0: + # user_repr = torch.zeros(self.user_emb_dim, device=self.device) + # elif self.pooling == 'Attn': + # user_repr = tot_entity_embedding[related_entities] + # user_repr = self.kg_attn(user_repr) + # else: + # assert self.pooling == 'Mean' + # user_repr = tot_entity_embedding[related_entities] + # user_repr = torch.mean(user_repr, dim=0) + # return user_repr + + # HiCore + item_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device) + if len(related_items) > 0: + item_embedding = self.encode_gating( + related_items, self.items_H_j, self.items_H_p, self.items_H_s, + self.gate_items_H_j, self.gate_items_H_p, self.gate_items_H_s, self.gate_items_H_o, + self.attn_items, self.hyper_conv_item, tot_item_embedding, + ) + entity_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device) + if len(related_entities) > 0: + entity_embedding = self.encode_gating( + related_entities, self.entitys_H_j, self.entitys_H_p, self.entitys_H_s, + self.gate_entitys_H_j, self.gate_entitys_H_p, self.gate_entitys_H_s, self.gate_entitys_H_o, + self.attn_entitys, self.hyper_conv_entity, tot_entity_embedding, + ) + word_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device) + if len(related_words) > 0: + word_embedding = self.encode_gating( + related_words, self.entitys_H_j, self.entitys_H_p, self.entitys_H_s, + self.gate_words_H_j, self.gate_words_H_p, self.gate_words_H_s, self.gate_words_H_o, + self.attn_words, self.hyper_conv_word, 
tot_word_embedding, + ) + + # 注意力机制 + if len(related_entities) == 0: + user_repr = self._attention_and_gating(item_embedding, entity_embedding, word_embedding, None) + else: + context_embedding = tot_entity_embedding[related_entities] + user_repr = self._attention_and_gating(item_embedding, entity_embedding, word_embedding, context_embedding) + return user_repr + + def process_hllm(self, hllm_data, id_dict): + res_data = [] + for raw_hyper_grapth in hllm_data: + if not isinstance(raw_hyper_grapth, list): + continue + temp_hyper_grapth = [] + for meta_data in raw_hyper_grapth: + if not isinstance(meta_data, int): + continue + if meta_data not in id_dict: + continue + temp_hyper_grapth.append(id_dict[meta_data]) + res_data.append(temp_hyper_grapth) + return res_data + + # 获取用户编码 + def encode_user(self, batch_related_items, batch_related_entities, batch_related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding): + user_repr_list = [] + for related_items, related_entities, related_words in zip(batch_related_items, batch_related_entities, batch_related_words): + user_repr = self.encode_user_repr(related_items, related_entities, related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding) + user_repr_list.append(user_repr) + user_embedding = torch.stack(user_repr_list, dim=0) + return user_embedding + + # 推荐模块 + def recommend(self, batch, mode): + # 获取数据 + conv_id = batch['conv_id'] + related_item = batch['related_item'] + related_entity = batch['related_entity'] + related_word = batch['related_word'] + item = batch['item'] + item_embedding = self.item_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type) + entity_embedding = self.entity_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type) + token_embedding = self.word_encoder(self.word_embedding.weight, self.edge_idx, self.edge_type) + + # 获取用户编码 + # start = perf_counter() + user_embedding = self.encode_user( + related_item, + related_entity, + 
related_word, + item_embedding, + entity_embedding, + token_embedding, + ) # (batch_size, emb_dim) + # print(f"{perf_counter() - start:.2f}") + + # 计算各实体得分 + scores = F.linear(user_embedding, entity_embedding, self.rec_bias.bias) # (batch_size, n_entity) + loss = self.rec_loss(scores, item) + return loss, scores + + def _starts(self, batch_size): + """Return bsz start tokens.""" + return self.START.detach().expand(batch_size, 1) + + def freeze_parameters(self): + freeze_models = [ + self.entity_embedding, + self.token_embedding, + self.item_encoder, + self.entity_encoder, + self.word_encoder, + self.hyper_conv_item, + self.hyper_conv_entity, + self.hyper_conv_word, + self.item_attn, + self.rec_bias + ] + if self.pooling == "Attn": + freeze_models.append(self.kg_attn) + freeze_models.append(self.kg_attn_his) + for model in freeze_models: + for p in model.parameters(): + p.requires_grad = False + + def _before_hyperconv(self, embeddings:torch.FloatTensor, hypergraph_items:List[int], edge_index:torch.LongTensor, adj): + sub_items = [] + edge_index = edge_index.cpu().numpy() + for item in hypergraph_items: + sub_items += [item] + list(adj.get(item, [])) + sub_items = list(set(sub_items)) + tot2sub = {tot:sub for sub, tot in enumerate(sub_items)} + sub_embeddings = embeddings[sub_items] + edge_index = [[tot2sub[v] for v in edge_index[0]], list(edge_index[1])] + sub_edge_index = torch.tensor(edge_index).long() + sub_edge_index = sub_edge_index.to(self.device) + return sub_embeddings, sub_edge_index, tot2sub + + # 获取超图后数据 + def encode_session(self, batch_related_items, batch_related_entities, batch_related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding): + """ + Return: session_repr (batch_size, batch_seq_len, token_emb_dim), mask (batch_size, batch_seq_len) + """ + session_repr_list = [] + for session_related_items, session_related_entities, session_related_words in zip(batch_related_items, batch_related_entities, batch_related_words): + # COLD START 
+ # if len(session_related_items) == 0 or len(session_related_words) == 0: + # if len(session_related_entities) == 0: + # session_repr_list.append(None) + # else: + # session_repr = tot_entity_embedding[session_related_entities] + # session_repr_list.append(session_repr) + # continue + + # HiCore + item_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device) + if len(session_related_items) > 0: + item_embedding = self.encode_gating( + session_related_items, self.items_H_j, self.items_H_p, self.items_H_s, + self.gate_items_H_j, self.gate_items_H_p, self.gate_items_H_s, self.gate_items_H_o, + self.attn_items, self.hyper_conv_item, tot_item_embedding, + ) + entity_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device) + if len(session_related_entities) > 0: + entity_embedding = self.encode_gating( + session_related_entities, self.entitys_H_j, self.entitys_H_p, self.entitys_H_s, + self.gate_entitys_H_j, self.gate_entitys_H_p, self.gate_entitys_H_s, self.gate_entitys_H_o, + self.attn_entitys, self.hyper_conv_entity, tot_entity_embedding, + ) + word_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device) + if len(session_related_words) > 0: + word_embedding = self.encode_gating( + session_related_words, self.entitys_H_j, self.entitys_H_p, self.entitys_H_s, + self.gate_words_H_j, self.gate_words_H_p, self.gate_words_H_s, self.gate_words_H_o, + self.attn_words, self.hyper_conv_word, tot_word_embedding, + ) + + # 数据拼接 + if len(session_related_entities) == 0: + session_repr = torch.cat((item_embedding, entity_embedding, word_embedding), dim=0) + session_repr_list.append(session_repr) + else: + context_embedding = tot_entity_embedding[session_related_entities] + session_repr = torch.cat((item_embedding, entity_embedding, context_embedding, word_embedding), dim=0) + session_repr_list.append(session_repr) + + batch_seq_len = max([session_repr.size(0) for session_repr in session_repr_list if session_repr is not None]) + mask_list = [] + for i in 
range(len(session_repr_list)): + if session_repr_list[i] is None: + mask_list.append([False] * batch_seq_len) + zero_repr = torch.zeros((batch_seq_len, self.kg_emb_dim), device=self.device, dtype=torch.float) + session_repr_list[i] = zero_repr + else: + mask_list.append([False] * (batch_seq_len - session_repr_list[i].size(0)) + [True] * session_repr_list[i].size(0)) + zero_repr = torch.zeros((batch_seq_len - session_repr_list[i].size(0), self.kg_emb_dim), + device=self.device, dtype=torch.float) + session_repr_list[i] = torch.cat((zero_repr, session_repr_list[i]), dim=0) + + session_repr_embedding = torch.stack(session_repr_list, dim=0) + session_repr_embedding = self.entity_to_token(session_repr_embedding) + # print("session_repr_embedding.shape", session_repr_embedding.shape) # [6, 7, 300] + return session_repr_embedding, torch.tensor(mask_list, device=self.device, dtype=torch.bool) + + # 生成对话 + def decode_forced(self, related_encoder_state, context_encoder_state, session_state, user_embedding, resp): + bsz = resp.size(0) + seqlen = resp.size(1) + inputs = resp.narrow(1, 0, seqlen - 1) + inputs = torch.cat([self._starts(bsz), inputs], 1) + latent, _ = self.decoder(inputs, related_encoder_state, context_encoder_state, session_state) + token_logits = F.linear(latent, self.token_embedding.weight) + user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1) + + user_latent = self.entity_to_token(user_embedding) + user_latent = user_latent.unsqueeze(1).expand(-1, seqlen, -1) + copy_latent = torch.cat((user_latent, latent), dim=-1) + copy_logits = self.copy_proj_2(torch.relu(self.copy_proj_1(copy_latent))) + if self.dataset == 'HReDial': + copy_logits = copy_logits * self.hredial_copy_mask.unsqueeze(0).unsqueeze(0) # not for tg-redial + sum_logits = token_logits + user_logits + copy_logits + _, preds = sum_logits.max(dim=-1) + return sum_logits, preds + + # 生成对话 - test + def decode_greedy(self, related_encoder_state, 
context_encoder_state, session_state, user_embedding): + bsz = context_encoder_state[0].shape[0] + xs = self._starts(bsz) + incr_state = None + logits = [] + for i in range(self.longest_label): + scores, incr_state = self.decoder(xs, related_encoder_state, context_encoder_state, session_state, incr_state) # incr_state is always None + scores = scores[:, -1:, :] + token_logits = F.linear(scores, self.token_embedding.weight) + user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1) + + user_latent = self.entity_to_token(user_embedding) + user_latent = user_latent.unsqueeze(1).expand(-1, 1, -1) + copy_latent = torch.cat((user_latent, scores), dim=-1) + copy_logits = self.copy_proj_2(torch.relu(self.copy_proj_1(copy_latent))) + if self.dataset == 'HReDial': + copy_logits = copy_logits * self.hredial_copy_mask.unsqueeze(0).unsqueeze(0) # not for tg-redial + sum_logits = token_logits + user_logits + copy_logits + probs, preds = sum_logits.max(dim=-1) + logits.append(scores) + xs = torch.cat([xs, preds], dim=1) + # check if everyone has generated an end token + all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz + if all_finished: + break + logits = torch.cat(logits, 1) + return logits, xs + + # 对话模块训练 + def converse(self, batch, mode): + # 获取数据 + conv_id = batch['conv_id'] + related_item = batch['related_item'] + related_entity = batch['related_entity'] + related_word = batch['related_word'] + response = batch['response'] + + related_tokens = batch['related_tokens'] + context_tokens = batch['context_tokens'] + + item_embedding = self.item_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type) + entity_embedding = self.entity_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type) + token_embedding = self.word_encoder(self.word_embedding.weight, self.edge_idx, self.edge_type) + + # 获取对话编码 + session_state = self.encode_session( + related_item, + related_entity, + related_word, + 
item_embedding, + entity_embedding, + token_embedding, + ) + + # 获取用户编码 + # start = perf_counter() + user_embedding = self.encode_user( + related_item, + related_entity, + related_word, + item_embedding, + entity_embedding, + token_embedding, + ) # (batch_size, emb_dim) + + # 获取 X_c、X_h + related_encoder_state = self.related_encoder(related_tokens) + context_encoder_state = self.context_encoder(context_tokens) + + # 对话生成 + if mode != 'test': + self.longest_label = max(self.longest_label, response.shape[1]) + logits, preds = self.decode_forced(related_encoder_state, context_encoder_state, session_state, user_embedding, response) + logits = logits.view(-1, logits.shape[-1]) + labels = response.view(-1) + return self.conv_loss(logits, labels), preds + else: + _, preds = self.decode_greedy(related_encoder_state, context_encoder_state, session_state, user_embedding) + return preds + + # 推荐模块和对话模块分开训练 + def forward(self, batch, mode, stage): + if len(self.gpu) >= 2: + self.edge_idx = self.edge_idx.cuda(torch.cuda.current_device()) + self.edge_type = self.edge_type.cuda(torch.cuda.current_device()) + if stage == "conv": + return self.converse(batch, mode) + if stage == "rec": + # start = perf_counter() + res = self.recommend(batch, mode) + # print(f"{perf_counter() - start:.2f}") + return res \ No newline at end of file diff --git a/HiCore/crslab/model/pretrained_models.py b/HiCore/crslab/model/pretrained_models.py new file mode 100644 index 0000000..33c20d6 --- /dev/null +++ b/HiCore/crslab/model/pretrained_models.py @@ -0,0 +1,64 @@ +# -*- encoding: utf-8 -*- +# @Time : 2021/1/6 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2021/1/7 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +from crslab.download import DownloadableFile + +"""Download links of pretrain models. + +Now we provide the following models: + +- `BERT`_: zh, en +- `GPT2`_: zh, en + +.. _BERT: + https://www.aclweb.org/anthology/N19-1423/ +.. 
_GPT2: + https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf + +""" + +resources = { + 'bert': { + 'zh': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1', + 'bert_zh.zip', + 'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3' + ) + }, + 'en': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1', + 'bert_en.zip', + '61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3' + ) + }, + }, + 'gpt2': { + 'zh': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1', + 'gpt2_zh.zip', + '5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43' + ) + }, + 'en': { + 'version': '0.1', + 'file': DownloadableFile( + 'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1', + 'gpt2_en.zip', + '518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e' + ) + } + } +} diff --git a/HiCore/crslab/model/utils/__init__.py b/HiCore/crslab/model/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/HiCore/crslab/model/utils/functions.py b/HiCore/crslab/model/utils/functions.py new file mode 100644 index 0000000..8c1345e --- /dev/null +++ b/HiCore/crslab/model/utils/functions.py @@ -0,0 +1,37 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/11/26 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +# UPDATE +# @Time : 2020/11/16 +# @Author : Xiaolei Wang +# @email : wxl1999@foxmail.com + +import torch + + +def edge_to_pyg_format(edge, type='RGCN'): + if type == 'RGCN': + edge_sets = torch.as_tensor(edge, 
dtype=torch.long) + edge_idx = edge_sets[:, :2].t() + edge_type = edge_sets[:, 2] + return edge_idx, edge_type + elif type == 'GCN': + edge_set = [[co[0] for co in edge], [co[1] for co in edge]] + return torch.as_tensor(edge_set, dtype=torch.long) + else: + raise NotImplementedError('type {} has not been implemented', type) + + +def sort_for_packed_sequence(lengths: torch.Tensor): + """ + :param lengths: 1D array of lengths + :return: sorted_lengths (lengths in descending order), sorted_idx (indices to sort), rev_idx (indices to retrieve original order) + + """ + sorted_idx = torch.argsort(lengths, descending=True) # idx to sort by length + rev_idx = torch.argsort(sorted_idx) # idx to retrieve original order + sorted_lengths = lengths[sorted_idx] + + return sorted_lengths, sorted_idx, rev_idx diff --git a/HiCore/crslab/model/utils/modules/__init__.py b/HiCore/crslab/model/utils/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/HiCore/crslab/model/utils/modules/attention.py b/HiCore/crslab/model/utils/modules/attention.py new file mode 100644 index 0000000..870186e --- /dev/null +++ b/HiCore/crslab/model/utils/modules/attention.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# @Time : 2020/11/22 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/11/24 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SelfAttentionBatch(nn.Module): + def __init__(self, dim, da, alpha=0.2, dropout=0.5): + super(SelfAttentionBatch, self).__init__() + self.dim = dim + self.da = da + self.alpha = alpha + self.dropout = dropout + self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True) + self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True) + nn.init.xavier_uniform_(self.a.data, gain=1.414) + nn.init.xavier_uniform_(self.b.data, gain=1.414) + + def forward(self, h): + # h: (N, dim) + e = 
torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1) + attention = F.softmax(e, dim=0) # (N) + return torch.matmul(attention, h) # (dim) + + +class SelfAttentionSeq(nn.Module): + def __init__(self, dim, da, alpha=0.2, dropout=0.5): + super(SelfAttentionSeq, self).__init__() + self.dim = dim + self.da = da + self.alpha = alpha + self.dropout = dropout + self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True) + self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True) + nn.init.xavier_uniform_(self.a.data, gain=1.414) + nn.init.xavier_uniform_(self.b.data, gain=1.414) + + def forward(self, h, mask=None, return_logits=False): + """ + For the padding tokens, its corresponding mask is True + if mask==[1, 1, 1, ...] + """ + # h: (batch, seq_len, dim), mask: (batch, seq_len) + e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b) # (batch, seq_len, 1) + if mask is not None: + full_mask = -1e30 * mask.float() + batch_mask = torch.sum((mask == False), -1).bool().float().unsqueeze(-1) # for all padding one, the mask=0 + mask = full_mask * batch_mask + e += mask.unsqueeze(-1) + attention = F.softmax(e, dim=1) # (batch, seq_len, 1) + # (batch, dim) + if return_logits: + return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1), attention.squeeze(-1) + else: + return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1) diff --git a/HiCore/crslab/model/utils/modules/transformer.py b/HiCore/crslab/model/utils/modules/transformer.py new file mode 100644 index 0000000..f571b91 --- /dev/null +++ b/HiCore/crslab/model/utils/modules/transformer.py @@ -0,0 +1,471 @@ +# @Time : 2020/11/22 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/11/24 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +"""Near infinity, useful as a large penalty for scoring 
when inf is bad.""" +NEAR_INF = 1e20 +NEAR_INF_FP16 = 65504 + + +def neginf(dtype): + """Returns a representable finite number near -inf for a dtype.""" + if dtype is torch.float16: + return -NEAR_INF_FP16 + else: + return -NEAR_INF + + +def _create_selfattn_mask(x): + # figure out how many timestamps we need + bsz = x.size(0) + time = x.size(1) + # make sure that we don't look into the future + mask = torch.tril(x.new(time, time).fill_(1)) + # broadcast across batch + mask = mask.unsqueeze(0).expand(bsz, -1, -1) + return mask + + +def create_position_codes(n_pos, dim, out): + position_enc = np.array([ + [pos / np.power(10000, 2 * j / dim) for j in range(dim // 2)] + for pos in range(n_pos) + ]) + + out.data[:, 0::2] = torch.as_tensor(np.sin(position_enc)) + out.data[:, 1::2] = torch.as_tensor(np.cos(position_enc)) + out.detach_() + out.requires_grad = False + + +def _normalize(tensor, norm_layer): + """Broadcast layer norm""" + size = tensor.size() + return norm_layer(tensor.view(-1, size[-1])).view(size) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_heads, dim, dropout=.0): + super(MultiHeadAttention, self).__init__() + self.n_heads = n_heads + self.dim = dim + + self.attn_dropout = nn.Dropout(p=dropout) # --attention-dropout + self.q_lin = nn.Linear(dim, dim) + self.k_lin = nn.Linear(dim, dim) + self.v_lin = nn.Linear(dim, dim) + # TODO: merge for the initialization step + nn.init.xavier_normal_(self.q_lin.weight) + nn.init.xavier_normal_(self.k_lin.weight) + nn.init.xavier_normal_(self.v_lin.weight) + # and set biases to 0 + self.out_lin = nn.Linear(dim, dim) + + nn.init.xavier_normal_(self.out_lin.weight) + + def forward(self, query, key=None, value=None, mask=None): + # Input is [B, query_len, dim] + # Mask is [B, key_len] (selfattn) or [B, key_len, key_len] (enc attn) + batch_size, query_len, dim = query.size() + assert dim == self.dim, \ + f'Dimensions do not match: {dim} query vs {self.dim} configured' + assert mask is not None, 'Mask is 
None, please specify a mask' + n_heads = self.n_heads + dim_per_head = dim // n_heads + scale = math.sqrt(dim_per_head) + + def prepare_head(tensor): + # input is [batch_size, seq_len, n_heads * dim_per_head] + # output is [batch_size * n_heads, seq_len, dim_per_head] + bsz, seq_len, _ = tensor.size() + tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head) + tensor = tensor.transpose(1, 2).contiguous().view( + batch_size * n_heads, + seq_len, + dim_per_head + ) + return tensor + + # q, k, v are the transformed values + if key is None and value is None: + # self attention + key = value = query + elif value is None: + # key and value are the same, but query differs + # self attention + value = key + _, key_len, dim = key.size() + + q = prepare_head(self.q_lin(query)) + k = prepare_head(self.k_lin(key)) + v = prepare_head(self.v_lin(value)) + + dot_prod = q.div_(scale).bmm(k.transpose(1, 2)) + # [B * n_heads, query_len, key_len] + attn_mask = ( + (mask == 0) + .view(batch_size, 1, -1, key_len) + .repeat(1, n_heads, 1, 1) + .expand(batch_size, n_heads, query_len, key_len) + .view(batch_size * n_heads, query_len, key_len) + ) + assert attn_mask.shape == dot_prod.shape + dot_prod.masked_fill_(attn_mask, neginf(dot_prod.dtype)) + + attn_weights = F.softmax(dot_prod, dim=-1).type_as(query) + attn_weights = self.attn_dropout(attn_weights) # --attention-dropout + + attentioned = attn_weights.bmm(v) + attentioned = ( + attentioned.type_as(query) + .view(batch_size, n_heads, query_len, dim_per_head) + .transpose(1, 2).contiguous() + .view(batch_size, query_len, dim) + ) + + out = self.out_lin(attentioned) + + return out + + +class TransformerFFN(nn.Module): + def __init__(self, dim, dim_hidden, relu_dropout=.0): + super(TransformerFFN, self).__init__() + self.relu_dropout = nn.Dropout(p=relu_dropout) + self.lin1 = nn.Linear(dim, dim_hidden) + self.lin2 = nn.Linear(dim_hidden, dim) + nn.init.xavier_uniform_(self.lin1.weight) + 
nn.init.xavier_uniform_(self.lin2.weight) + # TODO: initialize biases to 0 + + def forward(self, x): + x = F.relu(self.lin1(x)) + x = self.relu_dropout(x) # --relu-dropout + x = self.lin2(x) + return x + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + n_heads, + embedding_size, + ffn_size, + attention_dropout=0.0, + relu_dropout=0.0, + dropout=0.0, + ): + super().__init__() + self.dim = embedding_size + self.ffn_dim = ffn_size + self.attention = MultiHeadAttention( + n_heads, embedding_size, + dropout=attention_dropout, # --attention-dropout + ) + self.norm1 = nn.LayerNorm(embedding_size) + self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout) + self.norm2 = nn.LayerNorm(embedding_size) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, tensor, mask): + tensor = tensor + self.dropout(self.attention(tensor, mask=mask)) + tensor = _normalize(tensor, self.norm1) + tensor = tensor + self.dropout(self.ffn(tensor)) + tensor = _normalize(tensor, self.norm2) + tensor *= mask.unsqueeze(-1).type_as(tensor) + return tensor + + +class TransformerEncoder(nn.Module): + """ + Transformer encoder module. + + :param int n_heads: the number of multihead attention heads. + :param int n_layers: number of transformer layers. + :param int embedding_size: the embedding sizes. Must be a multiple of n_heads. + :param int ffn_size: the size of the hidden layer in the FFN + :param embedding: an embedding matrix for the bottom layer of the transformer. + If none, one is created for this encoder. + :param float dropout: Dropout used around embeddings and before layer + layer normalizations. This is used in Vaswani 2017 and works well on + large datasets. + :param float attention_dropout: Dropout performed after the multhead attention + softmax. This is not used in Vaswani 2017. + :param float relu_dropout: Dropout used after the ReLU in the FFN. Not used + in Vaswani 2017, but used in Tensor2Tensor. 
+ :param int padding_idx: Reserved padding index in the embeddings matrix. + :param bool learn_positional_embeddings: If off, sinusoidal embeddings are + used. If on, position embeddings are learned from scratch. + :param bool embeddings_scale: Scale embeddings relative to their dimensionality. + Found useful in fairseq. + :param bool reduction: If true, returns the mean vector for the entire encoding + sequence. + :param int n_positions: Size of the position embeddings matrix. + """ + + def __init__( + self, + n_heads, + n_layers, + embedding_size, + ffn_size, + vocabulary_size, + embedding=None, + dropout=0.0, + attention_dropout=0.0, + relu_dropout=0.0, + padding_idx=0, + learn_positional_embeddings=False, + embeddings_scale=False, + reduction=True, + n_positions=1024 + ): + super(TransformerEncoder, self).__init__() + + self.embedding_size = embedding_size + self.ffn_size = ffn_size + self.n_layers = n_layers + self.n_heads = n_heads + self.dim = embedding_size + self.embeddings_scale = embeddings_scale + self.reduction = reduction + self.padding_idx = padding_idx + # this is --dropout, not --relu-dropout or --attention-dropout + self.dropout = nn.Dropout(dropout) + self.out_dim = embedding_size + assert embedding_size % n_heads == 0, \ + 'Transformer embedding size must be a multiple of n_heads' + + # check input formats: + if embedding is not None: + assert ( + embedding_size is None or embedding_size == embedding.weight.shape[1] + ), "Embedding dim must match the embedding size." 
+ + if embedding is not None: + self.embeddings = embedding + else: + assert False + assert padding_idx is not None + self.embeddings = nn.Embedding( + vocabulary_size, embedding_size, padding_idx=padding_idx + ) + nn.init.normal_(self.embeddings.weight, 0, embedding_size ** -0.5) + + # create the positional embeddings + self.position_embeddings = nn.Embedding(n_positions, embedding_size) + if not learn_positional_embeddings: + create_position_codes( + n_positions, embedding_size, out=self.position_embeddings.weight + ) + else: + nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5) + + # build the model + self.layers = nn.ModuleList() + for _ in range(self.n_layers): + self.layers.append(TransformerEncoderLayer( + n_heads, embedding_size, ffn_size, + attention_dropout=attention_dropout, + relu_dropout=relu_dropout, + dropout=dropout, + )) + + def forward(self, input): + """ + input data is a FloatTensor of shape [batch, seq_len, dim] + mask is a ByteTensor of shape [batch, seq_len], filled with 1 when + inside the sequence and 0 outside. 
+ """ + mask = input != self.padding_idx + positions = (mask.cumsum(dim=1, dtype=torch.int64) - 1).clamp_(min=0) + tensor = self.embeddings(input) + if self.embeddings_scale: + tensor = tensor * np.sqrt(self.dim) + tensor = tensor + self.position_embeddings(positions).expand_as(tensor) + # --dropout on the embeddings + tensor = self.dropout(tensor) + + tensor *= mask.unsqueeze(-1).type_as(tensor) + for i in range(self.n_layers): + tensor = self.layers[i](tensor, mask) + + if self.reduction: + divisor = mask.type_as(tensor).sum(dim=1).unsqueeze(-1).clamp(min=1e-7) + output = tensor.sum(dim=1) / divisor + return output + else: + output = tensor + return output, mask + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + n_heads, + embedding_size, + ffn_size, + attention_dropout=0.0, + relu_dropout=0.0, + dropout=0.0, + ): + super().__init__() + self.dim = embedding_size + self.ffn_dim = ffn_size + self.dropout = nn.Dropout(p=dropout) + + self.self_attention = MultiHeadAttention( + n_heads, embedding_size, dropout=attention_dropout + ) + self.norm1 = nn.LayerNorm(embedding_size) + + self.encoder_attention = MultiHeadAttention( + n_heads, embedding_size, dropout=attention_dropout + ) + self.norm2 = nn.LayerNorm(embedding_size) + + self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout) + self.norm3 = nn.LayerNorm(embedding_size) + + def forward(self, x, encoder_output, encoder_mask): + decoder_mask = self._create_selfattn_mask(x) + # first self attn + residual = x + # don't peak into the future! 
+ x = self.self_attention(query=x, mask=decoder_mask) + x = self.dropout(x) # --dropout + x = x + residual + x = _normalize(x, self.norm1) + + residual = x + x = self.encoder_attention( + query=x, + key=encoder_output, + value=encoder_output, + mask=encoder_mask + ) + x = self.dropout(x) # --dropout + x = residual + x + x = _normalize(x, self.norm2) + + # finally the ffn + residual = x + x = self.ffn(x) + x = self.dropout(x) # --dropout + x = residual + x + x = _normalize(x, self.norm3) + + return x + + def _create_selfattn_mask(self, x): + # figure out how many timestamps we need + bsz = x.size(0) + time = x.size(1) + # make sure that we don't look into the future + mask = torch.tril(x.new(time, time).fill_(1)) + # broadcast across batch + mask = mask.unsqueeze(0).expand(bsz, -1, -1) + return mask + + +class TransformerDecoder(nn.Module): + """ + Transformer Decoder layer. + + :param int n_heads: the number of multihead attention heads. + :param int n_layers: number of transformer layers. + :param int embedding_size: the embedding sizes. Must be a multiple of n_heads. + :param int ffn_size: the size of the hidden layer in the FFN + :param embedding: an embedding matrix for the bottom layer of the transformer. + If none, one is created for this encoder. + :param float dropout: Dropout used around embeddings and before layer + layer normalizations. This is used in Vaswani 2017 and works well on + large datasets. + :param float attention_dropout: Dropout performed after the multhead attention + softmax. This is not used in Vaswani 2017. + :param int padding_idx: Reserved padding index in the embeddings matrix. + :param bool learn_positional_embeddings: If off, sinusoidal embeddings are + used. If on, position embeddings are learned from scratch. + :param bool embeddings_scale: Scale embeddings relative to their dimensionality. + Found useful in fairseq. + :param int n_positions: Size of the position embeddings matrix. 
+ """ + + def __init__( + self, + n_heads, + n_layers, + embedding_size, + ffn_size, + vocabulary_size, + embedding=None, + dropout=0.0, + attention_dropout=0.0, + relu_dropout=0.0, + embeddings_scale=True, + learn_positional_embeddings=False, + padding_idx=None, + n_positions=1024, + ): + super().__init__() + self.embedding_size = embedding_size + self.ffn_size = ffn_size + self.n_layers = n_layers + self.n_heads = n_heads + self.dim = embedding_size + self.embeddings_scale = embeddings_scale + self.dropout = nn.Dropout(p=dropout) # --dropout + + self.out_dim = embedding_size + assert embedding_size % n_heads == 0, \ + 'Transformer embedding size must be a multiple of n_heads' + + self.embeddings = embedding + + # create the positional embeddings + self.position_embeddings = nn.Embedding(n_positions, embedding_size) + if not learn_positional_embeddings: + create_position_codes( + n_positions, embedding_size, out=self.position_embeddings.weight + ) + else: + nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5) + + # build the model + self.layers = nn.ModuleList() + for _ in range(self.n_layers): + self.layers.append(TransformerDecoderLayer( + n_heads, embedding_size, ffn_size, + attention_dropout=attention_dropout, + relu_dropout=relu_dropout, + dropout=dropout, + )) + + def forward(self, input, encoder_state, incr_state=None): + encoder_output, encoder_mask = encoder_state + + seq_len = input.shape[1] + positions = input.new_empty(seq_len).long() + positions = torch.arange(seq_len, out=positions).unsqueeze(0) # (batch, seq_len) + tensor = self.embeddings(input) + if self.embeddings_scale: + tensor = tensor * np.sqrt(self.dim) + tensor = tensor + self.position_embeddings(positions).expand_as(tensor) + tensor = self.dropout(tensor) # --dropout + + for layer in self.layers: + tensor = layer(tensor, encoder_output, encoder_mask) + + return tensor, None diff --git a/HiCore/crslab/quick_start/__init__.py b/HiCore/crslab/quick_start/__init__.py new 
def run_crslab(config, save_data=False, restore_data=False, save_system=False, restore_system=False,
               interact=False, debug=False):
    """A fast running api, which includes the complete process of training and testing models on specified datasets.

    Args:
        config (Config or str): an instance of ``Config`` or path to the config file,
            which should be in ``yaml`` format. You can use default config provided in the `Github repo`_,
            or write it by yourself.
        save_data (bool): whether to save data. Defaults to False.
        restore_data (bool): whether to restore data. Defaults to False.
        save_system (bool): whether to save system. Defaults to False.
        restore_system (bool): whether to restore system. Defaults to False.
        interact (bool): whether to interact with the system. Defaults to False.
        debug (bool): whether to debug the system. Defaults to False.

    .. _Github repo:
       https://github.com/RUCAIBox/CRSLab

    """
    tokenize_cfg = config['tokenize']
    if isinstance(tokenize_cfg, str):
        # single tokenizer shared by every task
        dataset = get_dataset(config, tokenize_cfg, restore_data, save_data)
        side_data = dataset.side_data
        vocab = dataset.vocab

        train_dataloader = get_dataloader(config, dataset.train_data, vocab)
        valid_dataloader = get_dataloader(config, dataset.valid_data, vocab)
        test_dataloader = get_dataloader(config, dataset.test_data, vocab)
    else:
        # per-task tokenizers: build each tokenized dataset once and reuse it
        tokenized_dataset = {}
        train_dataloader, valid_dataloader, test_dataloader = {}, {}, {}
        vocab, side_data = {}, {}

        for task, tokenize in tokenize_cfg.items():
            if tokenize not in tokenized_dataset:
                tokenized_dataset[tokenize] = get_dataset(config, tokenize, restore_data, save_data)
            dataset = tokenized_dataset[tokenize]

            side_data[task] = dataset.side_data
            vocab[task] = dataset.vocab
            review_id2entity = dataset.review_id2entity

            train_dataloader[task] = get_dataloader(
                config, dataset.train_data, dataset.train_review, vocab[task], review_id2entity)
            valid_dataloader[task] = get_dataloader(
                config, dataset.valid_data, dataset.valid_review, vocab[task], review_id2entity)
            test_dataloader[task] = get_dataloader(
                config, dataset.test_data, dataset.test_review, vocab[task], review_id2entity)

    # system: either interactive session or full train/test cycle
    CRS = get_system(config, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system,
                     interact, debug)
    if interact:
        CRS.interact()
    else:
        CRS.fit()
        if save_system:
            CRS.save_model()
system_register_table = {
    'HiCore': HiCoreSystem,
}


def get_system(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,
               interact=False, debug=False):
    """Build and return the system class registered for ``opt['model_name']``.

    Raises:
        NotImplementedError: if no system is registered under the model name.
    """
    model_name = opt['model_name']
    try:
        system_class = system_register_table[model_name]
    except KeyError:
        raise NotImplementedError('The system with model [{}] in dataset [{}] has not been implemented'.
                                  format(model_name, opt['dataset']))
    system = system_class(opt, train_dataloader, valid_dataloader, test_dataloader, vocab,
                          side_data, restore_system, interact, debug)
    logger.info(f'[Build system {model_name}]')
    return system
class BaseSystem(ABC):
    """Base class for all systems.

    Handles device and seed setup, model construction, optimization (with
    gradient accumulation and clipping), LR scheduling, early stopping and
    checkpointing. Subclasses implement ``fit``, ``step`` and ``interact``.
    """

    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,
                 interact=False, debug=False):
        """

        Args:
            opt (dict): Indicating the hyper parameters.
            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.
            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.
            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.
            vocab (dict): Indicating the vocabulary.
            side_data (dict): Indicating the side data.
            restore_system (bool, optional): Indicating if we restore the system from a checkpoint. Defaults to False.
            interact (bool, optional): Indicating if we interact with system. Defaults to False.
            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.

        """
        self.opt = opt
        # device selection: exactly one GPU id -> cuda, otherwise CPU.
        # NOTE(review): a multi-GPU id list also falls back to CPU here --
        # confirm whether multi-GPU support was intentionally disabled.
        if opt["gpu"] == [-1]:
            self.device = torch.device('cpu')
        elif len(opt["gpu"]) == 1:
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        # seed every RNG we use for reproducibility
        if 'seed' in opt:
            seed = int(opt['seed'])
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            logger.info(f'[Set seed] {seed}')
        # data: in debug mode, train on the (smaller) valid split
        if debug:
            self.train_dataloader = valid_dataloader
            self.valid_dataloader = valid_dataloader
            self.test_dataloader = test_dataloader
        else:
            self.train_dataloader = train_dataloader
            self.valid_dataloader = valid_dataloader
            self.test_dataloader = test_dataloader
        self.vocab = vocab
        self.side_data = side_data
        # model: either one joint model, or separate rec/conv/policy models
        if 'model' in opt:
            self.model = get_model(opt, opt['model'], self.device, vocab, side_data).to(self.device)
        else:
            if 'rec_model' in opt:
                self.rec_model = get_model(opt, opt['rec_model'], self.device, vocab['rec'], side_data['rec']).to(
                    self.device)
            if 'conv_model' in opt:
                self.conv_model = get_model(opt, opt['conv_model'], self.device, vocab['conv'], side_data['conv']).to(
                    self.device)
            if 'policy_model' in opt:
                self.policy_model = get_model(opt, opt['policy_model'], self.device, vocab['policy'],
                                              side_data['policy']).to(self.device)
        model_file_name = opt.get('model_file', f'{opt["model_name"]}.pth')
        self.model_file = os.path.join(SAVE_PATH, model_file_name)
        if restore_system:
            self.restore_model()

        if not interact:
            self.evaluator = get_evaluator('standard', opt['dataset'], opt['rankfile'])

    def init_optim(self, opt, parameters):
        """Set up optimizer, scheduler, gradient accumulation/clipping and early stop.

        Args:
            opt (dict): stage-level optimization options (e.g. ``opt['rec']``).
            parameters: iterable of model parameters or parameter-group dicts.
        """
        self.optim_opt = opt
        parameters = list(parameters)
        if isinstance(parameters[0], dict):
            # materialize each param-group generator so it can be reused below
            for i, d in enumerate(parameters):
                parameters[i]['params'] = list(d['params'])

        # gradient accumulation
        self.update_freq = opt.get('update_freq', 1)
        self._number_grad_accum = 0

        # <= 0 disables clipping
        self.gradient_clip = opt.get('gradient_clip', -1)

        self.build_optimizer(parameters)
        self.build_lr_scheduler()

        # flatten param groups so clipping/grad-norm can iterate tensors directly
        if isinstance(parameters[0], dict):
            self.parameters = []
            for d in parameters:
                self.parameters.extend(d['params'])
        else:
            self.parameters = parameters

        # early stop
        self.need_early_stop = self.optim_opt.get('early_stop', False)
        if self.need_early_stop:
            logger.debug('[Enable early stop]')
            self.reset_early_stop_state()

    def build_optimizer(self, parameters):
        """Instantiate the optimizer named in ``optim_opt['optimizer']``.

        NOTE: pops ``'name'`` out of the config dict (mutates ``optim_opt``),
        so the same dict cannot be passed to this method twice.
        """
        optimizer_opt = self.optim_opt['optimizer']
        optimizer = optimizer_opt.pop('name')
        self.optimizer = optim_class[optimizer](parameters, **optimizer_opt)
        logger.info(f"[Build optimizer: {optimizer}]")

    def build_lr_scheduler(self):
        """
        Create the learning rate scheduler, and assign it to self.scheduler. This
        scheduler will be updated upon a call to receive_metrics. May also create
        self.warmup_scheduler, if appropriate.

        NOTE: like :meth:`build_optimizer`, this pops ``'name'`` from the
        scheduler config dict.
        """
        if self.optim_opt.get('lr_scheduler', None):
            lr_scheduler_opt = self.optim_opt['lr_scheduler']
            lr_scheduler = lr_scheduler_opt.pop('name')
            self.scheduler = lr_scheduler_class[lr_scheduler](self.optimizer, **lr_scheduler_opt)
            logger.info(f"[Build scheduler {lr_scheduler}]")

    def reset_early_stop_state(self):
        """Reset best-metric/patience tracking for a new training stage.

        Raises:
            ValueError: if ``stop_mode`` is neither ``'max'`` nor ``'min'``.
        """
        self.best_valid = None
        self.drop_cnt = 0
        self.impatience = self.optim_opt.get('impatience', 3)
        if self.optim_opt['stop_mode'] == 'max':
            self.stop_mode = 1
        elif self.optim_opt['stop_mode'] == 'min':
            self.stop_mode = -1
        else:
            # was a bare `raise` (RuntimeError: no active exception); raise a
            # meaningful error for a bad config instead
            raise ValueError(f"Invalid stop_mode {self.optim_opt['stop_mode']!r}, expected 'max' or 'min'")
        logger.debug('[Reset early stop state]')

    @abstractmethod
    def fit(self):
        """fit the whole system"""
        pass

    @abstractmethod
    def step(self, batch, stage, mode):
        """calculate loss and prediction for batch data under certrain stage and mode

        Args:
            batch (dict or tuple): batch data
            stage (str): recommendation/policy/conversation etc.
            mode (str): train/valid/test
        """
        pass

    def backward(self, loss):
        """empty grad, backward loss and update params

        Args:
            loss (torch.Tensor):
        """
        self._zero_grad()

        if self.update_freq > 1:
            # accumulating over update_freq batches: scale the loss down
            self._number_grad_accum = (self._number_grad_accum + 1) % self.update_freq
            loss /= self.update_freq
        # NOTE(review): passing the detached loss as the gradient argument
        # scales parameter gradients by the loss value; this is NOT equivalent
        # to plain loss.backward() -- confirm the weighting is intended.
        loss.backward(loss.clone().detach())

        self._update_params()

    def _zero_grad(self):
        # if we're accumulating gradients, don't actually zero things out yet
        if self._number_grad_accum != 0:
            return
        self.optimizer.zero_grad()

    def _update_params(self):
        if self.update_freq > 1:
            # gradient accumulation: only step every update_freq calls;
            # self._number_grad_accum is updated in backward()
            if self._number_grad_accum != 0:
                return

        if self.gradient_clip > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.parameters, self.gradient_clip
            )
            self.evaluator.optim_metrics.add('grad norm', AverageMetric(grad_norm))
            self.evaluator.optim_metrics.add(
                'grad clip ratio',
                AverageMetric(float(grad_norm > self.gradient_clip)),
            )
        else:
            grad_norm = compute_grad_norm(self.parameters)
            self.evaluator.optim_metrics.add('grad norm', AverageMetric(grad_norm))

        self.optimizer.step()

        if hasattr(self, 'scheduler'):
            self.scheduler.train_step()

    def adjust_lr(self, metric=None):
        """adjust learning rate w/o metric by scheduler

        Args:
            metric (optional): Defaults to None.
        """
        if not hasattr(self, 'scheduler') or self.scheduler is None:
            return
        self.scheduler.valid_step(metric)
        logger.debug('[Adjust learning rate after valid epoch]')

    def early_stop(self, metric):
        """Return True when the metric has not improved for ``impatience`` epochs."""
        if not self.need_early_stop:
            return False
        # stop_mode (+1/-1) turns both 'max' and 'min' into "larger is better"
        if self.best_valid is None or metric * self.stop_mode > self.best_valid * self.stop_mode:
            self.best_valid = metric
            self.drop_cnt = 0
            logger.info('[Get new best model]')
            return False
        self.drop_cnt += 1
        if self.drop_cnt >= self.impatience:
            logger.info('[Early stop]')
            return True
        return False  # was an implicit None; make the contract explicit

    def save_model(self):
        r"""Store the model parameters."""
        state = {}
        if hasattr(self, 'model'):
            state['model_state_dict'] = self.model.state_dict()
        if hasattr(self, 'rec_model'):
            state['rec_state_dict'] = self.rec_model.state_dict()
        if hasattr(self, 'conv_model'):
            state['conv_state_dict'] = self.conv_model.state_dict()
        if hasattr(self, 'policy_model'):
            state['policy_state_dict'] = self.policy_model.state_dict()

        os.makedirs(SAVE_PATH, exist_ok=True)
        torch.save(state, self.model_file)
        logger.info(f'[Save model into {self.model_file}]')

    def restore_model(self):
        r"""Restore the model parameters from ``self.model_file``.

        Raises:
            ValueError: if the checkpoint file does not exist.
        """
        if not os.path.exists(self.model_file):
            raise ValueError(f'Saved model [{self.model_file}] does not exist')
        checkpoint = torch.load(self.model_file, map_location=self.device)
        if hasattr(self, 'model'):
            self.model.load_state_dict(checkpoint['model_state_dict'])
        if hasattr(self, 'rec_model'):
            self.rec_model.load_state_dict(checkpoint['rec_state_dict'])
        if hasattr(self, 'conv_model'):
            self.conv_model.load_state_dict(checkpoint['conv_state_dict'])
        if hasattr(self, 'policy_model'):
            self.policy_model.load_state_dict(checkpoint['policy_state_dict'])
        logger.info(f'[Restore model from {self.model_file}]')

    @abstractmethod
    def interact(self):
        pass
class HiCoreSystem(BaseSystem):
    """Training/evaluation driver for the HiCore model (rec + conv stages)."""

    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,
                 interact=False, debug=False):
        """

        Args:
            opt (dict): Indicating the hyper parameters.
            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.
            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.
            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.
            vocab (dict): Indicating the vocabulary.
            side_data (dict): Indicating the side data.
            restore_system (bool, optional): Indicating if we restore the system from a checkpoint. Defaults to False.
            interact (bool, optional): Indicating if we interact with system. Defaults to False.
            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.

        """
        super(HiCoreSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data,
                                           restore_system, interact, debug)

        self.ind2tok = vocab['ind2tok']  # index -> token, for decoding generated responses
        self.end_token_idx = vocab['tok2ind']['__end__']  # generation stop token
        self.item_ids = side_data['item_entity_ids']  # entity ids that are recommendable items

        self.rec_optim_opt = opt['rec']
        self.conv_optim_opt = opt['conv']
        self.rec_epoch = self.rec_optim_opt['epoch']
        self.conv_epoch = self.conv_optim_opt['epoch']
        self.rec_batch_size = self.rec_optim_opt['batch_size']
        self.conv_batch_size = self.conv_optim_opt['batch_size']

    def rec_evaluate(self, rec_predict, item_label):
        """Rank candidate items per example and feed top-50 ranks to the evaluator.

        Args:
            rec_predict (torch.Tensor): scores over all entities, shape (batch, n_entities).
            item_label (torch.Tensor): gold item entity id per example.
        """
        rec_predict = rec_predict.cpu()
        rec_predict = rec_predict[:, self.item_ids]
        _, rec_ranks = torch.topk(rec_predict, 50, dim=-1)
        rec_ranks = rec_ranks.tolist()
        item_label = item_label.tolist()
        # Map entity id -> position in item_ids once per call (keeping the
        # first occurrence, like list.index) instead of an O(n) .index scan
        # per label.
        item_pos = {}
        for pos, item in enumerate(self.item_ids):
            item_pos.setdefault(item, pos)
        for rec_rank, label in zip(rec_ranks, item_label):
            self.evaluator.rec_evaluate(rec_rank, item_pos[label])

    def conv_evaluate(self, prediction, response, batch_user_id=None, batch_conv_id=None):
        """Decode predictions/references to text and feed them to the evaluator.

        Args:
            prediction (torch.Tensor): generated token ids, (batch, seq_len).
            response (torch.Tensor): reference token ids, (batch, seq_len).
            batch_user_id / batch_conv_id: accepted for interface compatibility;
                they do not affect evaluation (both branches of the original
                code were identical).
        """
        for p, r in zip(prediction.tolist(), response.tolist()):
            p_str = ind2txt(p, self.ind2tok, self.end_token_idx)
            r_str = ind2txt(r, self.ind2tok, self.end_token_idx)
            self.evaluator.gen_evaluate(p_str, [r_str], p)

    def step(self, batch, stage, mode):
        """Run one batch through the model for the given stage/mode.

        Args:
            batch (dict): batch tensors; moved to ``self.device`` in place.
            stage (str): 'rec' or 'conv'.
            mode (str): 'train', 'valid' or 'test'.
        """
        assert stage in ('rec', 'conv')
        assert mode in ('train', 'valid', 'test')

        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device)

        if stage == 'rec':
            rec_loss, rec_scores = self.model.forward(batch, mode, stage)
            rec_loss = rec_loss.sum()
            if mode == 'train':
                self.backward(rec_loss)
            else:
                self.rec_evaluate(rec_scores, batch['item'])
            rec_loss = rec_loss.item()
            self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss))
        else:
            if mode != 'test':
                # train/valid: teacher-forced loss (and predictions for valid)
                gen_loss, preds = self.model.forward(batch, mode, stage)
                if mode == 'train':
                    self.backward(gen_loss)
                else:
                    self.conv_evaluate(preds, batch['response'])
                gen_loss = gen_loss.item()
                self.evaluator.optim_metrics.add('gen_loss', AverageMetric(gen_loss))
                self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
            else:
                # test: free-running generation, no loss
                preds = self.model.forward(batch, mode, stage)
                self.conv_evaluate(preds, batch['response'], batch.get('user_id', None), batch['conv_id'])

    def train_recommender(self):
        """Train/validate/test the recommendation stage for ``rec_epoch`` epochs."""
        self.init_optim(self.rec_optim_opt, self.model.parameters())

        for epoch in range(self.rec_epoch):
            self.evaluator.reset_metrics()
            logger.info(f'[Recommendation epoch {str(epoch)}]')
            logger.info('[Train]')
            for batch in self.train_dataloader.get_rec_data(self.rec_batch_size):
                self.step(batch, stage='rec', mode='train')
            self.evaluator.report(epoch=epoch, mode='train')
            # val
            logger.info('[Valid]')
            with torch.no_grad():
                self.evaluator.reset_metrics()
                for batch in self.valid_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):
                    self.step(batch, stage='rec', mode='valid')
                self.evaluator.report(epoch=epoch, mode='valid')
                # early stop on validation loss
                metric = self.evaluator.optim_metrics['rec_loss']
                if self.early_stop(metric):
                    break
        # test
        logger.info('[Test]')
        with torch.no_grad():
            self.evaluator.reset_metrics()
            for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):
                self.step(batch, stage='rec', mode='test')
            self.evaluator.report(mode='test')

    def train_conversation(self):
        """Freeze rec-side parameters, then train/validate/test the conversation stage."""
        # The model is wrapped (has .module) when running on GPU.
        # Use .get so an unset CUDA_VISIBLE_DEVICES means CPU instead of KeyError.
        if os.environ.get("CUDA_VISIBLE_DEVICES", "-1") == '-1':
            self.model.freeze_parameters()
        else:
            self.model.module.freeze_parameters()
        self.init_optim(self.conv_optim_opt, self.model.parameters())

        for epoch in range(self.conv_epoch):
            self.evaluator.reset_metrics()
            logger.info(f'[Conversation epoch {str(epoch)}]')
            logger.info('[Train]')
            for batch in self.train_dataloader.get_conv_data(batch_size=self.conv_batch_size):
                self.step(batch, stage='conv', mode='train')
            self.evaluator.report(epoch=epoch, mode='train')
            # val
            logger.info('[Valid]')
            with torch.no_grad():
                self.evaluator.reset_metrics()
                for batch in self.valid_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):
                    self.step(batch, stage='conv', mode='valid')
                self.evaluator.report(epoch=epoch, mode='valid')
                # early stop on validation loss
                metric = self.evaluator.optim_metrics['gen_loss']
                if self.early_stop(metric):
                    break
        # test
        logger.info('[Test]')
        with torch.no_grad():
            self.evaluator.reset_metrics()
            for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):
                self.step(batch, stage='conv', mode='test')
            self.evaluator.report(mode='test')

    def fit(self):
        """Run the full pipeline: recommendation stage, then conversation stage."""
        self.train_recommender()
        self.train_conversation()

    def interact(self):
        pass
def compute_grad_norm(parameters, norm_type=2.0):
    """
    Compute norm over gradients of model parameters.

    :param parameters:
        the model parameters for gradient norm calculation. Iterable of
        Tensors or single Tensor
    :param norm_type:
        type of p-norm to use

    :returns:
        the computed gradient norm
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # skip parameters without gradients (e.g. frozen or unused ones)
    parameters = [p for p in parameters if p is not None and p.grad is not None]
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    return total_norm ** (1.0 / norm_type)


def ind2txt(inds, ind2tok, end_token_idx=None, unk_token='unk'):
    """Decode a sequence of token indices to a space-joined string.

    Stops at ``end_token_idx`` (exclusive); unknown indices map to ``unk_token``.
    Fix: compare with ``is not None`` so an end-token id of 0 is honored
    (the previous truthiness check silently ignored it).
    """
    sentence = []
    for ind in inds:
        if isinstance(ind, torch.Tensor):
            ind = ind.item()
        if end_token_idx is not None and ind == end_token_idx:
            break
        sentence.append(ind2tok.get(ind, unk_token))
    return ' '.join(sentence)


def ind2txt_with_slots(inds, slots, ind2tok, end_token_idx=None, unk_token='unk', slot_token='[ITEM]'):
    """Decode token indices to text, filling each ``slot_token`` with the next slot value.

    Same stopping/unknown handling as :func:`ind2txt`. Slots are consumed in
    order via an iterator (the previous list-slicing was O(n^2)).
    """
    slot_iter = iter(slots)
    sentence = []
    for ind in inds:
        if isinstance(ind, torch.Tensor):
            ind = ind.item()
        if end_token_idx is not None and ind == end_token_idx:
            break
        token = ind2tok.get(ind, unk_token)
        if token == slot_token:
            token = next(slot_iter)
        sentence.append(token)
    return ' '.join(sentence)


def ind2slot(inds, ind2slot):
    """Map each index in ``inds`` through the ``ind2slot`` lookup."""
    return [ind2slot[ind] for ind in inds]
class LRScheduler(ABC):
    """
    Class for LR Schedulers.

    Provides shared machinery: an optional linear warmup scheduler, step
    counting, and routing of train/valid updates to either the warmup or the
    main scheduler. Subclasses implement the abstract hooks
    ``train_adjust()`` and ``valid_adjust()``; ``__init__()`` here should only
    be invoked through a subclass.
    """

    def __init__(self, optimizer, warmup_steps: int = 0):
        """
        Initialize warmup scheduler. Specific main schedulers should be initialized
        in the subclasses.

        :param optimizer optimizer:
            Optimizer being used for training.
        :param int warmup_steps:
            Number of training step updates the warmup scheduler should take.
        """
        self._number_training_updates = 0
        self.warmup_steps = warmup_steps
        self._init_warmup_scheduler(optimizer)

    def _warmup_lr(self, step):
        """Multiplier (on the initial lr) for the linear warmup ramp."""
        return float(step) / float(max(1, self.warmup_steps))

    def _init_warmup_scheduler(self, optimizer):
        if self.warmup_steps <= 0:
            self.warmup_scheduler = None
        else:
            self.warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._warmup_lr)

    def _is_lr_warming_up(self):
        """True while the warmup scheduler is still active."""
        warmup = getattr(self, 'warmup_scheduler', None)
        return warmup is not None and self._number_training_updates <= self.warmup_steps

    def train_step(self):
        """
        Advance one training step: drive the warmup scheduler while warming up,
        otherwise delegate to the subclass via ``train_adjust()``.
        """
        self._number_training_updates += 1
        if self._is_lr_warming_up():
            self.warmup_scheduler.step()
        else:
            self.train_adjust()

    def valid_step(self, metric=None):
        """Delegate to ``valid_adjust()`` unless we are still warming up."""
        if self._is_lr_warming_up():
            # validation metrics must not drive the schedule during warmup
            return
        self.valid_adjust(metric)

    @abstractmethod
    def train_adjust(self):
        """
        Per-training-step adjustment hook; override in training-driven schedulers.
        """
        pass

    @abstractmethod
    def valid_adjust(self, metric):
        """
        Per-validation adjustment hook; override in metric-driven schedulers.
        """
        pass


class ReduceLROnPlateau(LRScheduler):
    """
    Scheduler that decays by a multiplicative rate when valid loss plateaus.
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001,
                 threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, warmup_steps=0):
        super(ReduceLROnPlateau, self).__init__(optimizer, warmup_steps)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer, mode=mode, factor=factor, patience=patience,
            threshold=threshold, threshold_mode=threshold_mode,
            cooldown=cooldown, min_lr=min_lr, eps=eps)

    def train_adjust(self):
        pass

    def valid_adjust(self, metric):
        self.scheduler.step(metric)
+ """ + + def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, warmup_steps=0): + super(StepLR, self).__init__(optimizer, warmup_steps) + self.scheduler = optim.lr_scheduler.StepLR(optimizer, step_size, gamma, last_epoch) + + def train_adjust(self): + pass + + def valid_adjust(self, metric=None): + self.scheduler.step() + + +class ConstantLR(LRScheduler): + def __init__(self, optimizer, warmup_steps=0): + super(ConstantLR, self).__init__(optimizer, warmup_steps) + + def train_adjust(self): + pass + + def valid_adjust(self, metric): + pass + + +class InvSqrtLR(LRScheduler): + """ + Scheduler that decays at an inverse square root rate. + """ + + def __init__(self, optimizer, invsqrt_lr_decay_gamma=-1, last_epoch=-1, warmup_steps=0): + """ + invsqrt_lr_decay_gamma determines the cycle length of the inverse square root + scheduler. + + When steps taken == invsqrt_lr_decay_gamma, the lr multiplier is 1 + """ + super(InvSqrtLR, self).__init__(optimizer, warmup_steps) + self.invsqrt_lr_decay_gamma = invsqrt_lr_decay_gamma + if invsqrt_lr_decay_gamma <= 0: + logger.warning( + '--lr-scheduler invsqrt requires a value for ' + '--invsqrt-lr-decay-gamma. Defaulting to set gamma to ' + '--warmup-updates value for backwards compatibility.' + ) + self.invsqrt_lr_decay_gamma = self.warmup_steps + + self.decay_factor = np.sqrt(max(1, self.invsqrt_lr_decay_gamma)) + self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._invsqrt_lr, last_epoch) + + def _invsqrt_lr(self, step): + return self.decay_factor / np.sqrt(max(1, self.invsqrt_lr_decay_gamma + step)) + + def train_adjust(self): + self.scheduler.step() + + def valid_adjust(self, metric): + # this is a training step lr scheduler, nothing to adjust in validation + pass + + +class CosineAnnealingLR(LRScheduler): + """ + Scheduler that decays by a cosine function. 
+ """ + + def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, warmup_steps=0): + """ + training_steps determines the cycle length of the cosine annealing. + + It indicates the number of steps from 1.0 multiplier to 0.0, which corresponds + to going from cos(0) to cos(pi) + """ + super(CosineAnnealingLR, self).__init__(optimizer, warmup_steps) + self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min, last_epoch) + + def train_adjust(self): + self.scheduler.step() + + def valid_adjust(self, metric): + pass + + +class CosineAnnealingWarmRestartsLR(LRScheduler): + def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, warmup_steps=0): + super(CosineAnnealingWarmRestartsLR, self).__init__(optimizer, warmup_steps) + self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult, eta_min, last_epoch) + + def train_adjust(self): + self.scheduler.step() + + def valid_adjust(self, metric): + pass + + +class TransformersLinearLR(LRScheduler): + """ + Scheduler that decays linearly. + """ + + def __init__(self, optimizer, training_steps, warmup_steps=0): + """ + training_steps determines the cycle length of the linear annealing. 
class TransformersLinearLR(LRScheduler):
    """
    Scheduler that decays linearly.
    """

    def __init__(self, optimizer, training_steps, warmup_steps=0):
        """
        training_steps determines the cycle length of the linear annealing:
        the number of steps from multiplier 1.0 down to 0.0.
        """
        super().__init__(optimizer, warmup_steps)
        # the linear ramp-down covers only the post-warmup portion
        self.training_steps = training_steps - warmup_steps
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._linear_lr)

    def _linear_lr(self, step):
        remaining = float(self.training_steps - step) / float(max(1, self.training_steps))
        return max(0.0, remaining)

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        pass


class TransformersCosineLR(LRScheduler):
    """Cosine decay over ``training_steps`` with ``num_cycles`` half-waves."""

    def __init__(self, optimizer, training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1,
                 warmup_steps: int = 0):
        super().__init__(optimizer, warmup_steps)
        self.training_steps = training_steps - warmup_steps
        self.num_cycles = num_cycles
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._cosine_lr, last_epoch)

    def _cosine_lr(self, step):
        progress = float(step) / float(max(1, self.training_steps))
        value = 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress))
        return max(0.0, value)

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        pass


class TransformersCosineWithHardRestartsLR(LRScheduler):
    """Cosine decay that restarts from 1.0 at the start of each of ``num_cycles`` cycles."""

    def __init__(self, optimizer, training_steps: int, num_cycles: int = 1, last_epoch: int = -1,
                 warmup_steps: int = 0):
        super().__init__(optimizer, warmup_steps)
        self.training_steps = training_steps - warmup_steps
        self.num_cycles = num_cycles
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._cosine_with_hard_restarts_lr, last_epoch)

    def _cosine_with_hard_restarts_lr(self, step):
        progress = float(step) / float(max(1, self.training_steps))
        if progress >= 1.0:
            return 0.0
        # modulo restarts the cosine at the top of each cycle
        value = 0.5 * (1.0 + math.cos(math.pi * ((float(self.num_cycles) * progress) % 1.0)))
        return max(0.0, value)

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        pass
pass + + +class TransformersPolynomialDecayLR(LRScheduler): + def __init__(self, optimizer, training_steps, lr_end=1e-7, power=1.0, last_epoch=-1, warmup_steps=0): + super(TransformersPolynomialDecayLR, self).__init__(optimizer, warmup_steps) + self.training_steps = training_steps - warmup_steps + self.lr_init = optimizer.defaults["lr"] + self.lr_end = lr_end + assert self.lr_init > lr_end, f"lr_end ({lr_end}) must be be smaller than initial lr ({self.lr_init})" + self.power = power + self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._polynomial_decay_lr, last_epoch) + + def _polynomial_decay_lr(self, step): + if step > self.training_steps: + return self.lr_end / self.lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = self.lr_init - self.lr_end + decay_steps = self.training_steps + pct_remaining = 1 - step / decay_steps + decay = lr_range * pct_remaining ** self.power + self.lr_end + return decay / self.lr_init # as LambdaLR multiplies by lr_init + + def train_adjust(self): + self.scheduler.step() + + def valid_adjust(self, metric): + pass diff --git a/HiCore/run_crslab.py b/HiCore/run_crslab.py new file mode 100644 index 0000000..2773bfb --- /dev/null +++ b/HiCore/run_crslab.py @@ -0,0 +1,45 @@ +# @Time : 2020/11/22 +# @Author : Kun Zhou +# @Email : francis_kun_zhou@163.com + +# UPDATE: +# @Time : 2020/11/24, 2021/1/9 +# @Author : Kun Zhou, Xiaolei Wang +# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com + +import argparse +import warnings + +from crslab.config import Config + +warnings.filterwarnings('ignore') + +if __name__ == '__main__': + # parse args + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, + default='config/crs/hicore/hredial.yaml', help='config file(yaml) path') + parser.add_argument('-g', '--gpu', type=str, default='-1', + help='specify GPU id(s) to use, we now support multiple GPUs. 
Defaults to CPU(-1).') + parser.add_argument('-sd', '--save_data', action='store_true', + help='save processed dataset') + parser.add_argument('-rd', '--restore_data', action='store_true', + help='restore processed dataset') + parser.add_argument('-ss', '--save_system', action='store_true', + help='save trained system') + parser.add_argument('-rs', '--restore_system', action='store_true', + help='restore trained system') + parser.add_argument('-d', '--debug', action='store_true', + help='use valid dataset to debug your system') + parser.add_argument('-i', '--interact', action='store_true', + help='interact with your system instead of training') + parser.add_argument('-s', '--seed', type=int, default=2020) + parser.add_argument('-p', '--pretrain', action='store_true', help='use pretrain weights') + parser.add_argument('-e', '--pretrain_epoch', type=int, default=9999, help='pretrain epoch') + args, _ = parser.parse_known_args() + config = Config(args.config, args.gpu, args.debug, args.seed, args.pretrain, args.pretrain_epoch) + + from crslab.quick_start import run_crslab + + run_crslab(config, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact, + args.debug) diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f51b477 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,23 @@ +[project] +name = "hicore" +version = "0.1.0" +description = "The implementation of Mitigating Matthew Effect: Multi-Hypergraph Boosted Multi-Interest Self-Supervised Learning for Conversational Recommendation (EMNLP 2024)" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "argparse>=1.4.0", + "dgl>=2.2.1", + "loguru>=0.7.3", + "nltk>=3.9.1", + "pyyaml>=6.0.2", + "requests>=2.32.4", + "scikit-learn>=1.7.1", + "torch>=2.8.0", + "torch-geometric>=2.6.1", + "tqdm>=4.67.1", + "transformers>=4.55.0", +] + +[[tool.uv.index]] 
+url="https://pypi.tuna.tsinghua.edu.cn/simple" +default=true