init HyCoRec

Author: Tokisakix
Date:   2025-08-07 00:34:22 +08:00
Commit: 48ee240418

67 changed files with 6985 additions and 0 deletions

.gitignore (vendored, 14 lines)

@@ -0,0 +1,14 @@
*.pyc
*.json
*.pkl
*.txt
*.log
.built
*.npy
*.npz
res/
log/
.python-version
uv.lock


@@ -0,0 +1,50 @@
# dataset
dataset: DuRecDial
tokenize: jieba
# dataloader
related_truncate: 1024
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: HyCoRec
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# HyCoRec-CHANGE
mha_n_heads: 4
pooling: Attn
extension_strategy: Adaptive
# optim
rec:
epoch: 1
batch_size: 64
early_stop: False
stop_mode: min
impatience: 2
optimizer:
name: Adam
lr: !!float 1e-3
conv:
epoch: 1
batch_size: 128
impatience: 1
optimizer:
name: Adam
lr: !!float 1e-3
lr_scheduler:
name: ReduceLROnPlateau
patience: 3
factor: 0.5
gradient_clip: 0.1
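
For orientation, a minimal sketch of how a config like the one above resolves into nested Python dicts. It only assumes PyYAML (already used by the Config class later in this commit) and an illustrative file path, config/crs/hycorec/durecdial.yaml.

import yaml

# Load the DuRecDial config above into a plain dict of options.
with open("config/crs/hycorec/durecdial.yaml", "r", encoding="utf-8") as f:
    opt = yaml.safe_load(f)

print(opt["model"])                   # 'HyCoRec'
print(opt["rec"]["batch_size"])       # 64
print(opt["rec"]["optimizer"]["lr"])  # 0.001 -- the `!!float 1e-3` tag forces a float
print(opt["conv"]["lr_scheduler"])    # {'name': 'ReduceLROnPlateau', 'patience': 3, 'factor': 0.5}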


@@ -0,0 +1,50 @@
# dataset
dataset: OpenDialKG
tokenize: nltk
# dataloader
related_truncate: 1024
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: HyCoRec
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# HyCoRec-CHANGE
mha_n_heads: 4
pooling: Mean
extension_strategy: Adaptive
# optim
rec:
epoch: 1
batch_size: 256
early_stop: False
stop_mode: min
impatience: 2
optimizer:
name: Adam
lr: !!float 1e-3
conv:
epoch: 1
batch_size: 128
impatience: 1
optimizer:
name: Adam
lr: !!float 1e-3
lr_scheduler:
name: ReduceLROnPlateau
patience: 3
factor: 0.5
gradient_clip: 0.1


@@ -0,0 +1,50 @@
# dataset
dataset: ReDial
tokenize: nltk
# dataloader
related_truncate: 1024
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: HyCoRec
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# HyCoRec-CHANGE
mha_n_heads: 4
pooling: Mean
extension_strategy: Adaptive
# optim
rec:
epoch: 1
batch_size: 256
early_stop: False
stop_mode: min
impatience: 2
optimizer:
name: Adam
lr: !!float 1e-3
conv:
epoch: 1
batch_size: 128
impatience: 1
optimizer:
name: Adam
lr: !!float 1e-3
lr_scheduler:
name: ReduceLROnPlateau
patience: 3
factor: 0.5
gradient_clip: 0.1


@@ -0,0 +1,50 @@
# dataset
dataset: TGReDial
tokenize: pkuseg
# dataloader
related_truncate: 1024
context_truncate: 256
response_truncate: 30
scale: 1
# model
model: HyCoRec
token_emb_dim: 300
kg_emb_dim: 128
num_bases: 8
n_heads: 2
n_layers: 2
ffn_size: 300
dropout: 0.1
attention_dropout: 0.0
relu_dropout: 0.1
learn_positional_embeddings: false
embeddings_scale: true
reduction: false
n_positions: 1024
user_proj_dim: 512
# HyCoRec-CHANGE
mha_n_heads: 4
pooling: Attn
extension_strategy: Adaptive
# optim
rec:
epoch: 1
batch_size: 64
early_stop: False
stop_mode: min
impatience: 2
optimizer:
name: Adam
lr: !!float 1e-3
conv:
epoch: 1
batch_size: 128
impatience: 1
optimizer:
name: Adam
lr: !!float 1e-3
lr_scheduler:
name: ReduceLROnPlateau
patience: 3
factor: 0.5
gradient_clip: 0.1


@@ -0,0 +1 @@
__version__ = '0.1.0'


@@ -0,0 +1,32 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/29
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
"""Config module which loads parameters for the whole system.
Attributes:
SAVE_PATH (str): where the system output is saved.
DATA_PATH (str): root directory of all downloaded data.
DATASET_PATH (str): where datasets are stored.
MODEL_PATH (str): where model-related data is stored.
PRETRAIN_PATH (str): where pretrained models are stored.
EMBEDDING_PATH (str): where pretrained embeddings are stored, used to evaluate embedding-related metrics.
"""
import os
from os.path import dirname, realpath
from .config import Config
ROOT_PATH = dirname(dirname(dirname(realpath(__file__))))
SAVE_PATH = os.path.join(ROOT_PATH, 'save')
DATA_PATH = os.path.join(ROOT_PATH, 'data')
DATASET_PATH = os.path.join(DATA_PATH, 'dataset')
MODEL_PATH = os.path.join(DATA_PATH, 'model')
PRETRAIN_PATH = os.path.join(MODEL_PATH, 'pretrain')
EMBEDDING_PATH = os.path.join(DATA_PATH, 'embedding')
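
As a rough illustration (the absolute prefix below is hypothetical and depends on where the repository lives), these constants resolve to subdirectories of the project root, and the dataset classes later in this commit join them with their own folders:

import os
from crslab.config import DATASET_PATH, MODEL_PATH, PRETRAIN_PATH, ROOT_PATH

# With the repo checked out at /home/user/HyCoRec (illustrative):
#   ROOT_PATH     -> /home/user/HyCoRec
#   DATASET_PATH  -> /home/user/HyCoRec/data/dataset
#   PRETRAIN_PATH -> /home/user/HyCoRec/data/model/pretrain
dpath = os.path.join(DATASET_PATH, "durecdial", "jieba")  # as DuRecDialDataset does below
print(ROOT_PATH, MODEL_PATH, dpath)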


@@ -0,0 +1,155 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/23, 2021/1/9
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
import json
import os
import time
from pprint import pprint
import yaml
from loguru import logger
from tqdm import tqdm
import random
import numpy as np
import torch
class Config:
"""Configurator module that load the defined parameters."""
def __init__(self, config_file, gpu='-1', debug=False, seed=2020, pretrain=False, pretrain_epoch=None):
"""Load parameters and set log level.
Args:
config_file (str): path to the config file, which should be in ``yaml`` format.
You can use default config provided in the `Github repo`_, or write it by yourself.
debug (bool, optional): whether to enable debug function during running. Defaults to False.
.. _Github repo:
https://github.com/RUCAIBox/CRSLab
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.opt = self.load_yaml_configs(config_file)
# pretrain epoch
self.opt['pretrain'] = pretrain
self.opt['pretrain_epoch'] = pretrain_epoch
# gpu
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
self.opt['gpu'] = [i for i in range(len(gpu.split(',')))] if gpu != '-1' else [-1]
# dataset
dataset = self.opt['dataset']
tokenize = self.opt['tokenize']
if isinstance(tokenize, dict):
tokenize = ', '.join(tokenize.values())
# model
model = self.opt.get('model', None)
rec_model = self.opt.get('rec_model', None)
conv_model = self.opt.get('conv_model', None)
policy_model = self.opt.get('policy_model', None)
if model:
model_name = model
else:
models = []
if rec_model:
models.append(rec_model)
if conv_model:
models.append(conv_model)
if policy_model:
models.append(policy_model)
model_name = '_'.join(models)
self.opt['model_name'] = model_name
self.opt['rankfile'] = f"{model_name}-{dataset}.json"
# log
log_name = self.opt.get("log_name", dataset + '_' + model_name + '_' + time.strftime("%Y-%m-%d-%H-%M-%S",
time.localtime())) + ".log"
if not os.path.exists("log"):
os.makedirs("log")
logger.remove()
if debug:
level = 'DEBUG'
else:
level = 'INFO'
logger.add(os.path.join("log", log_name), level=level)
logger.add(lambda msg: tqdm.write(msg, end=''), colorize=True, level=level)
logger.info(f"[Dataset: {dataset} tokenized in {tokenize}]")
if model:
logger.info(f'[Model: {model}]')
if rec_model:
logger.info(f'[Recommendation Model: {rec_model}]')
if conv_model:
logger.info(f'[Conversation Model: {conv_model}]')
if policy_model:
logger.info(f'[Policy Model: {policy_model}]')
logger.info("[Config]" + '/n' + json.dumps(self.opt, indent=4))
@staticmethod
def load_yaml_configs(filename):
"""This function reads ``yaml`` file to build config dictionary
Args:
filename (str): path to ``yaml`` config
Returns:
dict: config
"""
config_dict = dict()
with open(filename, 'r', encoding='utf-8') as f:
config_dict.update(yaml.safe_load(f.read()))
return config_dict
def __setitem__(self, key, value):
if not isinstance(key, str):
raise TypeError("index must be a str.")
self.opt[key] = value
def __getitem__(self, item):
if item in self.opt:
return self.opt[item]
else:
return None
def get(self, item, default=None):
"""Get value of corrsponding item in config
Args:
item (str): key to query in config
default (optional): default value for item if not found in config. Defaults to None.
Returns:
value of corresponding item in config
"""
if item in self.opt:
return self.opt[item]
else:
return default
def __contains__(self, key):
if not isinstance(key, str):
raise TypeError("index must be a str.")
return key in self.opt
def __str__(self):
return str(self.opt)
def __repr__(self):
return self.__str__()
if __name__ == '__main__':
opt_dict = Config('config/crs/hycorec/hredial.yaml')
pprint(opt_dict)
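
A usage sketch mirroring the ``__main__`` block above; the GPU id is illustrative and the nested sections come straight from the YAML files earlier in this commit.

from crslab.config import Config

config = Config('config/crs/hycorec/hredial.yaml', gpu='-1', debug=True, seed=2020)

dataset_name = config['dataset']          # value from the YAML, e.g. 'HReDial'
rec_opt = config['rec']                   # nested YAML sections stay plain dicts
lr = rec_opt['optimizer']['lr']           # 0.001
assert 'model_name' in config             # derived and stored during __init__
missing = config.get('no_such_key', 42)   # .get() falls back to the given default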


@@ -0,0 +1,85 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24, 2020/12/29, 2020/12/17
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com
# @Time : 2021/10/06
# @Author : Zhipeng Zhao
# @Email : oran_official@outlook.com
"""Data module which reads, processes and batches data for the whole system
Attributes:
dataset_register_table (dict): record all supported datasets
dataset_language_map (dict): map each dataset to its language
dataloader_register_table (dict): map each model to its dataloader
"""
from crslab.data.dataloader import *
from crslab.data.dataset import *
dataset_register_table = {
'HReDial': HReDialDataset,
'HTGReDial': HTGReDialDataset,
'OpenDialKG': OpenDialKGDataset,
'DuRecDial': DuRecDialDataset,
'ReDial': ReDialDataset,
'TGReDial': TGReDialDataset,
}
dataset_language_map = {
'ReDial': 'en',
'TGReDial': 'zh',
'HReDial': 'en',
'HTGReDial': 'zh',
'OpenDialKG': 'en',
'DuRecDial': 'zh',
}
dataloader_register_table = {
'HyCoRec': HyCoRecDataLoader,
}
def get_dataset(opt, tokenize, restore, save) -> BaseDataset:
"""get and process dataset
Args:
opt (Config or dict): config for dataset or the whole system.
tokenize (str): how to tokenize the dataset.
restore (bool): whether to restore saved dataset which has been processed.
save (bool): whether to save dataset after processing.
Returns:
processed dataset
"""
dataset = opt['dataset']
if dataset in dataset_register_table:
return dataset_register_table[dataset](opt, tokenize, restore, save)
else:
raise NotImplementedError(f'The dataset [{dataset}] has not been implemented')
def get_dataloader(opt, dataset, vocab) -> BaseDataLoader:
"""get dataloader to batchify dataset
Args:
opt (Config or dict): config for dataloader or the whole system.
dataset: processed raw data, no side data.
vocab (dict): all kinds of useful size, idx and map between token and idx.
Returns:
dataloader
"""
model_name = opt['model_name']
if model_name in dataloader_register_table:
return dataloader_register_table[model_name](opt, dataset, vocab)
else:
raise NotImplementedError(f'The dataloader [{model_name}] has not been implemented')
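
A hedged sketch of the intended call sequence (the YAML path is illustrative; the attribute names follow BaseDataset and BaseDataLoader defined later in this commit):

from crslab.config import Config
from crslab.data import get_dataset, get_dataloader

opt = Config('config/crs/hycorec/durecdial.yaml', gpu='-1')

# Download and preprocess the dataset registered under opt['dataset'].
dataset = get_dataset(opt, opt['tokenize'], restore=False, save=True)

# Batchify the processed training split with the dataloader registered for opt['model_name'].
train_loader = get_dataloader(opt, dataset.train_data, dataset.vocab)
for batch in train_loader.get_rec_data(batch_size=opt['rec']['batch_size']):
    pass  # feed each batch to the recommendation module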


@@ -0,0 +1,2 @@
from .base import BaseDataLoader
from .hycorec import HyCoRecDataLoader


@@ -0,0 +1,211 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/23, 2020/12/29
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
import random
from abc import ABC
from loguru import logger
from math import ceil
from tqdm import tqdm
class BaseDataLoader(ABC):
"""Abstract class of dataloader
Notes:
``'scale'`` can be set in config to limit the size of dataset.
"""
def __init__(self, opt, dataset):
"""
Args:
opt (Config or dict): config for dataloader or the whole system.
dataset: dataset
"""
self.opt = opt
self.dataset = dataset
self.scale = opt.get('scale', 1)
assert 0 < self.scale <= 1
def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None):
"""Collate batch data for system to fit
Args:
batch_fn (func): function to collate data
batch_size (int):
shuffle (bool, optional): Defaults to True.
process_fn (func, optional): function to process dataset before batchify. Defaults to None.
Yields:
tuple or dict of torch.Tensor: batch data for system to fit
"""
dataset = self.dataset
if process_fn is not None:
dataset = process_fn()
# logger.info('[Finish dataset process before batchify]')
dataset = dataset[:ceil(len(dataset) * self.scale)]
logger.debug(f'[Dataset size: {len(dataset)}]')
batch_num = ceil(len(dataset) / batch_size)
idx_list = list(range(len(dataset)))
if shuffle:
random.shuffle(idx_list)
for start_idx in tqdm(range(batch_num)):
batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size]
batch = [dataset[idx] for idx in batch_idx]
batch = batch_fn(batch)
if batch is False:
# batch_fn returns False to signal an invalid batch; skip it
continue
yield batch
def get_conv_data(self, batch_size, shuffle=True):
"""get_data wrapper for conversation.
You can implement your own process_fn in ``conv_process_fn``, batch_fn in ``conv_batchify``.
Args:
batch_size (int):
shuffle (bool, optional): Defaults to True.
Yields:
tuple or dict of torch.Tensor: batch data for conversation.
"""
return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn)
def get_rec_data(self, batch_size, shuffle=True):
"""get_data wrapper for recommendation.
You can implement your own process_fn in ``rec_process_fn``, batch_fn in ``rec_batchify``.
Args:
batch_size (int):
shuffle (bool, optional): Defaults to True.
Yields:
tuple or dict of torch.Tensor: batch data for recommendation.
"""
return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn)
def get_policy_data(self, batch_size, shuffle=True):
"""get_data wrapper for policy.
You can implement your own process_fn in ``self.policy_process_fn``, batch_fn in ``policy_batchify``.
Args:
batch_size (int):
shuffle (bool, optional): Defaults to True.
Yields:
tuple or dict of torch.Tensor: batch data for policy.
"""
return self.get_data(self.policy_batchify, batch_size, shuffle, self.policy_process_fn)
def conv_process_fn(self):
"""Process whole data for conversation before batch_fn.
Returns:
processed dataset. Defaults to return the same as `self.dataset`.
"""
return self.dataset
def conv_batchify(self, batch):
"""batchify data for conversation after process.
Args:
batch (list): processed batch dataset.
Returns:
batch data for the system to train conversation part.
"""
raise NotImplementedError('dataloader must implement conv_batchify() method')
def rec_process_fn(self):
"""Process whole data for recommendation before batch_fn.
Returns:
processed dataset. Defaults to return the same as `self.dataset`.
"""
return self.dataset
def rec_batchify(self, batch):
"""batchify data for recommendation after process.
Args:
batch (list): processed batch dataset.
Returns:
batch data for the system to train recommendation part.
"""
raise NotImplementedError('dataloader must implement rec_batchify() method')
def policy_process_fn(self):
"""Process whole data for policy before batch_fn.
Returns:
processed dataset. Defaults to return the same as `self.dataset`.
"""
return self.dataset
def policy_batchify(self, batch):
"""batchify data for policy after process.
Args:
batch (list): processed batch dataset.
Returns:
batch data for the system to train policy part.
"""
raise NotImplementedError('dataloader must implement policy_batchify() method')
def retain_recommender_target(self):
"""keep data whose role is recommender.
Returns:
Recommender part of ``self.dataset``.
"""
dataset = []
for conv_dict in tqdm(self.dataset):
if conv_dict['role'] == 'Recommender':
dataset.append(conv_dict)
return dataset
def rec_interact(self, data):
"""process user input data for system to recommend.
Args:
data: user input data.
Returns:
data for system to recommend.
"""
pass
def conv_interact(self, data):
"""Process user input data for system to converse.
Args:
data: user input data.
Returns:
data for system to converse.
"""
pass
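
To make the contract concrete, here is a minimal hypothetical subclass that only supports the recommendation task; the HyCoRecDataLoader added in the next file follows the same pattern with richer fields.

import torch
from crslab.data.dataloader.base import BaseDataLoader

class ToyRecDataLoader(BaseDataLoader):
    """Hypothetical dataloader whose samples are dicts with 'role' and an integer 'item' id."""

    def rec_process_fn(self):
        # Keep only the recommender turns, as most CRS dataloaders do.
        return self.retain_recommender_target()

    def rec_batchify(self, batch):
        # Collate a list of samples into one tensor batch; returning False skips the batch.
        items = [sample['item'] for sample in batch]
        if not items:
            return False
        return {'item': torch.tensor(items, dtype=torch.long)}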


@@ -0,0 +1,144 @@
# -*- encoding: utf-8 -*-
# @Time : 2021/5/26
# @Author : Chenzhan Shang
# @email : czshang@outlook.com
import pickle
import torch
from tqdm import tqdm
from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt
class HyCoRecDataLoader(BaseDataLoader):
"""Dataloader for model KBRD.
Notes:
You can set the following parameters in config:
- ``"context_truncate"``: the maximum length of context.
- ``"response_truncate"``: the maximum length of response.
- ``"entity_truncate"``: the maximum length of mentioned entities in context.
The following values must be specified in ``vocab``:
- ``"pad"``
- ``"start"``
- ``"end"``
- ``"pad_entity"``
the above values specify the id of needed special token.
"""
def __init__(self, opt, dataset, vocab):
"""
Args:
opt (Config or dict): config for dataloader or the whole system.
dataset: data for model.
vocab (dict): all kinds of useful size, idx and map between token and idx.
"""
super().__init__(opt, dataset)
self.pad_token_idx = vocab["tok2ind"]["__pad__"]
self.start_token_idx = vocab["tok2ind"]["__start__"]
self.end_token_idx = vocab["tok2ind"]["__end__"]
self.split_token_idx = vocab["tok2ind"].get("_split_", None)
self.related_truncate = opt.get("related_truncate", None)
self.context_truncate = opt.get("context_truncate", None)
self.response_truncate = opt.get("response_truncate", None)
self.entity_truncate = opt.get("entity_truncate", None)
self.review_entity2id = vocab["entity2id"]
return
def rec_process_fn(self):
augment_dataset = []
for conv_dict in tqdm(self.dataset):
if conv_dict["role"] == "Recommender":
for item in conv_dict["items"]:
augment_conv_dict = {
"conv_id": conv_dict["conv_id"],
"related_item": conv_dict["item"],
"related_entity": conv_dict["entity"],
"related_word": conv_dict["word"],
"item": item,
}
augment_dataset.append(augment_conv_dict)
return augment_dataset
def rec_batchify(self, batch):
batch_related_item = []
batch_related_entity = []
batch_related_word = []
batch_movies = []
batch_conv_id = []
for conv_dict in batch:
batch_related_item.append(conv_dict["related_item"])
batch_related_entity.append(conv_dict["related_entity"])
batch_related_word.append(conv_dict["related_word"])
batch_movies.append(conv_dict["item"])
batch_conv_id.append(conv_dict["conv_id"])
res = {
"conv_id": batch_conv_id,
"related_item": batch_related_item,
"related_entity": batch_related_entity,
"related_word": batch_related_word,
"item": torch.tensor(batch_movies, dtype=torch.long),
}
return res
def conv_process_fn(self, *args, **kwargs):
return self.retain_recommender_target()
def conv_batchify(self, batch):
batch_related_tokens = []
batch_context_tokens = []
batch_related_item = []
batch_related_entity = []
batch_related_word = []
batch_response = []
batch_conv_id = []
for conv_dict in batch:
batch_related_tokens.append(
truncate(conv_dict["tokens"][-1], self.related_truncate, truncate_tail=False)
)
batch_context_tokens.append(
truncate(merge_utt(
conv_dict["tokens"],
start_token_idx=self.start_token_idx,
split_token_idx=self.split_token_idx,
final_token_idx=self.end_token_idx
), self.context_truncate, truncate_tail=False)
)
batch_related_item.append(conv_dict["item"])
batch_related_entity.append(conv_dict["entity"])
batch_related_word.append(conv_dict["word"])
batch_response.append(
add_start_end_token_idx(truncate(conv_dict["response"], self.response_truncate - 2),
start_token_idx=self.start_token_idx,
end_token_idx=self.end_token_idx))
batch_conv_id.append(conv_dict["conv_id"])
res = {
"related_tokens": padded_tensor(batch_related_tokens, self.pad_token_idx, pad_tail=False),
"context_tokens": padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),
"related_item": batch_related_item,
"related_entity": batch_related_entity,
"related_word": batch_related_word,
"response": padded_tensor(batch_response, self.pad_token_idx),
"conv_id": batch_conv_id,
}
return res
def policy_batchify(self, *args, **kwargs):
pass
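
A sketch of what iterating the conversation task looks like with this dataloader (the YAML path is illustrative; the key descriptions follow conv_batchify above):

from crslab.config import Config
from crslab.data import get_dataset, get_dataloader

opt = Config('config/crs/hycorec/durecdial.yaml', gpu='-1')
dataset = get_dataset(opt, opt['tokenize'], restore=False, save=False)
loader = get_dataloader(opt, dataset.train_data, dataset.vocab)  # -> HyCoRecDataLoader

for batch in loader.get_conv_data(batch_size=opt['conv']['batch_size']):
    # 'context_tokens': left-padded LongTensor of the merged context utterances
    # 'related_tokens': left-padded LongTensor of the last utterance only
    # 'response':       right-padded LongTensor with __start__/__end__ added
    print(batch['context_tokens'].shape, batch['response'].shape)
    break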


@@ -0,0 +1,182 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/10
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/20, 2020/12/15
# @Author : Xiaolei Wang, Yuanhang Zhou
# @email : wxl1999@foxmail.com, sdzyh002@gmail
# UPDATE
# @Time : 2021/10/06
# @Author : Zhipeng Zhao
# @Email : oran_official@outlook.com
from copy import copy
import torch
from typing import List, Union, Optional
def padded_tensor(
items: List[Union[List[int], torch.LongTensor]],
pad_idx: int = 0,
pad_tail: bool = True,
max_len: Optional[int] = None,
) -> torch.LongTensor:
"""Create a padded matrix from an uneven list of lists.
Returns padded matrix.
Matrix is right-padded (filled to the right) by default; set ``pad_tail=False`` to left-pad instead.
If the input items are tensors, the output keeps their dtype and device.
:param list[iter[int]] items: List of items
:param int pad_idx: the value to use for padding
:param bool pad_tail: if True, pad at the end of each row; otherwise pad at the beginning
:param int max_len: if None, the max length is the maximum item length
:returns: padded tensor.
:rtype: Tensor[int64]
"""
# number of items
n = len(items)
# length of each item
lens: List[int] = [len(item) for item in items] # type: ignore
# max in time dimension
t = max(lens) if max_len is None else max_len
# if input tensors are empty, we should expand to nulls
t = max(t, 1)
if isinstance(items[0], torch.Tensor):
# keep type of input tensors, they may already be cuda ones
output = items[0].new(n, t) # type: ignore
else:
output = torch.LongTensor(n, t) # type: ignore
output.fill_(pad_idx)
for i, (item, length) in enumerate(zip(items, lens)):
if length == 0:
# skip empty items
continue
if not isinstance(item, torch.Tensor):
# put non-tensors into a tensor
item = torch.tensor(item, dtype=torch.long) # type: ignore
if pad_tail:
# place at beginning
output[i, :length] = item
else:
# place at end
output[i, t - length:] = item
return output
def get_onehot(data_list, categories) -> torch.Tensor:
"""Transform lists of label into one-hot.
Args:
data_list (list of list of int): source data.
categories (int): #label class.
Returns:
torch.Tensor: one-hot labels.
"""
onehot_labels = []
for label_list in data_list:
onehot_label = torch.zeros(categories)
for label in label_list:
onehot_label[label] = 1.0 / len(label_list)
onehot_labels.append(onehot_label)
return torch.stack(onehot_labels, dim=0)
def add_start_end_token_idx(vec: list, start_token_idx: int = None, end_token_idx: int = None):
"""Can choose to add start token in the beginning and end token in the end.
Args:
vec: source list composed of indexes.
start_token_idx: index of start token.
end_token_idx: index of end token.
Returns:
list: list added start or end token index.
"""
res = copy(vec)
if start_token_idx:
res.insert(0, start_token_idx)
if end_token_idx:
res.append(end_token_idx)
return res
def truncate(vec, max_length, truncate_tail=True):
"""truncate vec to make its length no more than max length.
Args:
vec (list): source list.
max_length (int)
truncate_tail (bool, optional): Defaults to True.
Returns:
list: truncated vec.
"""
if max_length is None:
return vec
if len(vec) <= max_length:
return vec
if max_length == 0:
return []
if truncate_tail:
return vec[:max_length]
else:
return vec[-max_length:]
def merge_utt(conversation, start_token_idx=None, split_token_idx=None, keep_split_in_tail=False, final_token_idx=None):
"""merge utterances in one conversation.
Args:
conversation (list of list of int): conversation consist of utterances consist of tokens.
split_token_idx (int): index of split token. Defaults to None.
keep_split_in_tail (bool): split in tail or head. Defaults to False.
final_token_idx (int): index of final token. Defaults to None.
Returns:
list: tokens of all utterances in one list.
"""
merged_conv = []
if start_token_idx:
merged_conv.append(start_token_idx)
for utt in conversation:
for token in utt:
merged_conv.append(token)
if split_token_idx:
merged_conv.append(split_token_idx)
if split_token_idx and not keep_split_in_tail:
merged_conv = merged_conv[:-1]
if final_token_idx:
merged_conv.append(final_token_idx)
return merged_conv
def merge_utt_replace(conversation, detect_token=None, replace_token=None, method="in"):
if method == 'in':
replaced_conv = []
for utt in conversation:
for token in utt:
if detect_token in token:
replaced_conv.append(replace_token)
else:
replaced_conv.append(token)
return replaced_conv
else:
return [token.replace(detect_token, replace_token) for utt in conversation for token in utt]
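
A few concrete calls with tiny made-up token ids, showing how these helpers compose (the same composition conv_batchify uses above):

from crslab.data.dataloader.utils import add_start_end_token_idx, merge_utt, padded_tensor, truncate

utts = [[5, 6, 7], [8, 9]]  # two utterances of token ids
merged = merge_utt(utts, start_token_idx=1, split_token_idx=4, final_token_idx=2)
# -> [1, 5, 6, 7, 4, 8, 9, 2]: trailing split token dropped, final token appended

clipped = truncate(merged, 5, truncate_tail=False)  # keep only the last 5 tokens
# -> [7, 4, 8, 9, 2]

resp = add_start_end_token_idx([8, 9], start_token_idx=1, end_token_idx=2)
# -> [1, 8, 9, 2]

batch = padded_tensor([clipped, resp], pad_idx=0, pad_tail=False)
# -> LongTensor of shape (2, 5); the shorter row is left-padded with 0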


@@ -0,0 +1,8 @@
from .base import BaseDataset
from .hredial import HReDialDataset
from .htgredial import HTGReDialDataset
from .opendialkg import OpenDialKGDataset
from .durecdial import DuRecDialDataset
from .redial import ReDialDataset
from .tgredial import TGReDialDataset


@@ -0,0 +1,171 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/23, 2020/12/13
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
import os
import pickle as pkl
from abc import ABC, abstractmethod
import numpy as np
from loguru import logger
from crslab.download import build
class BaseDataset(ABC):
"""Abstract class of dataset
Notes:
``'embedding'`` can be specified in config to use pretrained word embedding.
"""
def __init__(self, opt, dpath, resource, restore=False, save=False):
"""Download resource, load, process data. Support restore and save processed dataset.
Args:
opt (Config or dict): config for dataset or the whole system.
dpath (str): where to store dataset.
resource (dict): version, download file and special token idx of tokenized dataset.
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
save (bool): whether to save dataset after processing. Defaults to False.
"""
self.opt = opt
self.dpath = dpath
# download
dfile = resource['file']
build(dpath, dfile, version=resource['version'])
if not restore:
# load and process
train_data, valid_data, test_data, self.vocab = self._load_data()
logger.info('[Finish data load]')
self.train_data, self.valid_data, self.test_data, self.side_data = self._data_preprocess(train_data, valid_data, test_data)
embedding = opt.get('embedding', None)
if embedding:
self.side_data["embedding"] = np.load(os.path.join(self.dpath, embedding))
logger.debug(f'[Load pretrained embedding {embedding}]')
logger.info('[Finish data preprocess]')
else:
self._load_other_data()
self.train_data, self.valid_data, self.test_data, self.side_data, self.vocab = self._load_from_restore()
if save:
data = (self.train_data, self.valid_data, self.test_data, self.side_data, self.vocab)
self._save_to_one(data)
@abstractmethod
def _load_other_data(self):
"""
Load other auxiliary data required by the concrete dataset (e.g. reviews, knowledge graphs or side data).
"""
pass
@abstractmethod
def _load_data(self):
"""Load dataset.
Returns:
(any, any, any, dict):
raw train, valid and test data.
vocab: all kinds of useful size, idx and map between token and idx.
"""
pass
@abstractmethod
def _data_preprocess(self, train_data, valid_data, test_data):
"""Process raw train, valid, test data.
Args:
train_data: train dataset.
valid_data: valid dataset.
test_data: test dataset.
Returns:
(list of dict, dict):
train/valid/test_data, each dict is in the following format::
{
'role' (str):
'Seeker' or 'Recommender',
'user_profile' (list of list of int):
id of tokens of sentences of user profile,
'context_tokens' (list of list int):
token ids of preprocessed contextual dialogs,
'response' (list of int):
token ids of the ground-truth response,
'interaction_history' (list of int):
id of items which have interaction of the user in current turn,
'context_items' (list of int):
item ids mentioned in context,
'items' (list of int):
item ids mentioned in current turn, we only keep
those in entity kg for comparison,
'context_entities' (list of int):
if necessary, id of entities in context,
'context_words' (list of int):
if necessary, id of words in context,
'context_policy' (list of list of list):
policy of each context turn, one turn may have several policies,
where first is action and second is keyword,
'target' (list): policy of current turn,
'final' (list): final goal for current turn
}
side_data, which is in the following format::
{
'entity_kg': {
'edge' (list of tuple): (head_entity_id, tail_entity_id, relation_id),
'n_relation' (int): number of distinct relations,
'entity' (list of str): str of entities, used for entity linking
}
'word_kg': {
'edge' (list of tuple): (head_entity_id, tail_entity_id),
'entity' (list of str): str of entities, used for entity linking
}
'item_entity_ids' (list of int): entity id of each item;
}
"""
pass
def _load_from_restore(self, file_name="all_data.pkl"):
"""Restore saved dataset.
Args:
file_name (str): file of saved dataset. Defaults to "all_data.pkl".
"""
if not os.path.exists(os.path.join(self.dpath, file_name)):
raise ValueError(f'Saved dataset [{file_name}] does not exist')
with open(os.path.join(self.dpath, file_name), 'rb') as f:
dataset = pkl.load(f)
logger.info(f'Restore dataset from [{file_name}]')
return dataset
def _save_to_one(self, data, file_name="all_data.pkl"):
"""Save all processed dataset and vocab into one file.
Args:
data (tuple): all dataset and vocab.
file_name (str, optional): file to save dataset. Defaults to "all_data.pkl".
"""
if not os.path.exists(self.dpath):
os.makedirs(self.dpath)
save_path = os.path.join(self.dpath, file_name)
with open(save_path, 'wb') as f:
pkl.dump(data, f)
logger.info(f'[Save dataset to {file_name}]')


@@ -0,0 +1 @@
from .durecdial import DuRecDialDataset


@@ -0,0 +1,281 @@
# @Time : 2020/12/21
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/12/21, 2021/1/2
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
r"""
DuRecDial
=========
References:
Liu, Zeming, et al. `"Towards Conversational Recommendation over Multi-Type Dialogs."`_ in ACL 2020.
.. _"Towards Conversational Recommendation over Multi-Type Dialogs.":
https://www.aclweb.org/anthology/2020.acl-main.98/
"""
import json
import os
from copy import copy
from loguru import logger
from tqdm import tqdm
from crslab.config import DATASET_PATH
from crslab.data.dataset.base import BaseDataset
from .resources import resources
class DuRecDialDataset(BaseDataset):
"""
Attributes:
train_data: train dataset.
valid_data: valid dataset.
test_data: test dataset.
vocab (dict): ::
{
'tok2ind': map from token to index,
'ind2tok': map from index to token,
'entity2id': map from entity to index,
'id2entity': map from index to entity,
'word2id': map from word to index,
'vocab_size': len(self.tok2ind),
'n_entity': max(self.entity2id.values()) + 1,
'n_word': max(self.word2id.values()) + 1,
}
Notes:
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
"""
def __init__(self, opt, tokenize, restore=False, save=False):
"""
Args:
opt (Config or dict): config for dataset or the whole system.
tokenize (str): how to tokenize dataset.
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
save (bool): whether to save dataset after processing. Defaults to False.
"""
resource = resources[tokenize]
self.special_token_idx = resource['special_token_idx']
self.unk_token_idx = self.special_token_idx['unk']
dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize)
self.train_review = None
self.valid_review = None
self.test_review = None
super().__init__(opt, dpath, resource, restore, save)
def _load_data(self):
train_data, valid_data, test_data = self._load_raw_data()
self._load_vocab()
self._load_other_data()
vocab = {
'tok2ind': self.tok2ind,
'ind2tok': self.ind2tok,
'entity2id': self.entity2id,
'id2entity': self.id2entity,
'word2id': self.word2id,
'vocab_size': len(self.tok2ind),
'n_entity': self.n_entity,
'n_word': self.n_word,
}
vocab.update(self.special_token_idx)
return train_data, valid_data, test_data, vocab
def _load_raw_data(self):
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
train_data = json.load(f)
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
valid_data = json.load(f)
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
test_data = json.load(f)
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
return train_data, valid_data, test_data
def _load_vocab(self):
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
def _load_other_data(self):
# entity kg
with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f:
self.entity2id = json.load(f) # {entity: entity_id}
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
self.n_entity = max(self.entity2id.values()) + 1
# entity_subkg.txt stores one triple per line: head_entity \t relation \t tail_entity
self.entity_kg = open(os.path.join(self.dpath, 'entity_subkg.txt'), encoding='utf-8')
logger.debug(
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]")
# hownet
# {concept: concept_id}
with open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8') as f:
self.word2id = json.load(f)
self.n_word = max(self.word2id.values()) + 1
# {concept \t relation\t concept}
self.word_kg = open(os.path.join(self.dpath, 'hownet_subkg.txt'), encoding='utf-8')
logger.debug(
f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet_subkg.txt')}]")
def _data_preprocess(self, train_data, valid_data, test_data):
processed_train_data = self._raw_data_process(train_data)
logger.debug("[Finish train data process]")
processed_valid_data = self._raw_data_process(valid_data)
logger.debug("[Finish valid data process]")
processed_test_data = self._raw_data_process(test_data)
logger.debug("[Finish test data process]")
processed_side_data = self._side_data_process()
logger.debug("[Finish side data process]")
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
def _raw_data_process(self, raw_data):
augmented_convs = [self._convert_to_id(idx, conversation) for idx, conversation in enumerate(tqdm(raw_data))]
augmented_conv_dicts = []
for conv in tqdm(augmented_convs):
augmented_conv_dicts.extend(self._augment_and_add(conv))
return augmented_conv_dicts
def _convert_to_id(self, idx, conversation):
augmented_convs = []
last_role = None
conv_id = conversation.get("conv_id", idx)
related_item = []
related_entity = []
related_word = []
for utt in conversation['dialog']:
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id]
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]
related_item += item_ids
related_entity += entity_ids
related_word += word_ids
if utt["role"] == last_role:
augmented_convs[-1]["text"] += text_token_ids
augmented_convs[-1]["item"] += item_ids
augmented_convs[-1]["entity"] += entity_ids
augmented_convs[-1]["word"] += word_ids
else:
augmented_convs.append({
"conv_id": conv_id,
"role": utt["role"],
"text": text_token_ids,
"item": related_item,
"entity": related_entity,
"word": related_word
})
last_role = utt["role"]
return augmented_convs
def _augment_and_add(self, raw_conv_dict):
augmented_conv_dicts = []
context_tokens, context_entities, context_words, context_items = [], [], [], []
entity_set, word_set = set(), set()
for i, conv in enumerate(raw_conv_dict):
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"]
if len(context_tokens) > 0:
conv_dict = {
"conv_id": conv["conv_id"],
"role": conv["role"],
"tokens": copy(context_tokens),
"response": text_tokens,
"item": copy(context_items),
"entity": copy(context_entities),
"word": copy(context_words),
"items": movies
}
augmented_conv_dicts.append(conv_dict)
context_tokens.append(text_tokens)
context_items += movies
for entity in entities + movies:
if entity not in entity_set:
entity_set.add(entity)
context_entities.append(entity)
for word in words:
if word not in word_set:
word_set.add(word)
context_words.append(word)
return augmented_conv_dicts
def _side_data_process(self):
processed_entity_kg = self._entity_kg_process()
logger.debug("[Finish entity KG process]")
processed_word_kg = self._word_kg_process()
logger.debug("[Finish word KG process]")
with open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8') as f:
item_entity_ids = json.load(f)
logger.debug('[Load movie entity ids]')
side_data = {
"entity_kg": processed_entity_kg,
"word_kg": processed_word_kg,
"item_entity_ids": item_entity_ids,
}
return side_data
def _entity_kg_process(self):
edge_list = [] # [(entity, entity, relation)]
for line in self.entity_kg:
triple = line.strip().split('\t')
e0 = self.entity2id[triple[0]]
e1 = self.entity2id[triple[2]]
r = triple[1]
edge_list.append((e0, e1, r))
edge_list.append((e1, e0, r))
edge_list.append((e0, e0, 'SELF_LOOP'))
if e1 != e0:
edge_list.append((e1, e1, 'SELF_LOOP'))
relation2id, edges, entities = dict(), set(), set()
for h, t, r in edge_list:
if r not in relation2id:
relation2id[r] = len(relation2id)
edges.add((h, t, relation2id[r]))
entities.add(self.id2entity[h])
entities.add(self.id2entity[t])
return {
'edge': list(edges),
'n_relation': len(relation2id),
'entity': list(entities)
}
def _word_kg_process(self):
edges = set() # {(entity, entity)}
entities = set()
for line in self.word_kg:
triple = line.strip().split('\t')
entities.add(triple[0])
entities.add(triple[2])
e0 = self.word2id[triple[0]]
e1 = self.word2id[triple[2]]
edges.add((e0, e1))
edges.add((e1, e0))
# edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
return {
'edge': list(edges),
'entity': list(entities)
}
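
A sketch of the save/restore round trip this class inherits from BaseDataset; the YAML path is illustrative, and the cache file name all_data.pkl comes from the base class above.

from crslab.config import Config
from crslab.data.dataset import DuRecDialDataset

opt = Config('config/crs/hycorec/durecdial.yaml', gpu='-1')

# First run: download the resource, preprocess, and cache everything to all_data.pkl.
dataset = DuRecDialDataset(opt, tokenize='jieba', restore=False, save=True)

# Later runs: skip preprocessing and reload the cached tuple instead.
dataset = DuRecDialDataset(opt, tokenize='jieba', restore=True, save=False)
train_data, vocab = dataset.train_data, dataset.vocab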


@@ -0,0 +1,70 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from crslab.download import DownloadableFile
resources = {
'jieba': {
'version': '0.3',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1',
'durecdial_jieba.zip',
'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813',
),
'special_token_idx': {
'pad': 0,
'start': 1,
'end': 2,
'unk': 3,
'pad_entity': 0,
'pad_word': 0,
},
},
'bert': {
'version': '0.3',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1',
'durecdial_bert.zip',
'0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5'
),
'special_token_idx': {
'pad': 0,
'start': 101,
'end': 102,
'unk': 100,
'sent_split': 2,
'word_split': 3,
'pad_entity': 0,
'pad_word': 0,
'pad_topic': 0
},
},
'gpt2': {
'version': '0.3',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1',
'durecdial_gpt2.zip',
'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4'
),
'special_token_idx': {
'pad': 0,
'start': 101,
'end': 102,
'unk': 100,
'cls': 101,
'sep': 102,
'sent_split': 2,
'word_split': 3,
'pad_entity': 0,
'pad_word': 0,
'pad_topic': 0,
},
}
}
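
A brief sketch of how one of these entries is consumed, mirroring DuRecDialDataset.__init__ and BaseDataset above (the import path for this resources module is assumed from the package layout):

import os
from crslab.config import DATASET_PATH
from crslab.download import build
from crslab.data.dataset.durecdial.resources import resources  # assumed module path

resource = resources['jieba']
dpath = os.path.join(DATASET_PATH, 'durecdial', 'jieba')

# Download and unpack durecdial_jieba.zip into dpath (normally skipped once built).
build(dpath, resource['file'], version=resource['version'])

pad_idx = resource['special_token_idx']['pad']  # 0
unk_idx = resource['special_token_idx']['unk']  # 3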


@@ -0,0 +1 @@
from .hredial import HReDialDataset


@@ -0,0 +1,209 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/23, 2021/1/3, 2020/12/19
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail
r"""
ReDial
======
References:
Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS 2018.
.. _`"Towards deep conversational recommendations."`:
https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html
"""
import json
import os
import pickle as pkl
from copy import copy
from loguru import logger
import torch
from tqdm import tqdm
from crslab.config import DATASET_PATH
from crslab.data.dataset.base import BaseDataset
from .resources import resources
class HReDialDataset(BaseDataset):
"""
Attributes:
train_data: train dataset.
valid_data: valid dataset.
test_data: test dataset.
vocab (dict): ::
{
'tok2ind': map from token to index,
'ind2tok': map from index to token,
'entity2id': map from entity to index,
'id2entity': map from index to entity,
'word2id': map from word to index,
'vocab_size': len(self.tok2ind),
'n_entity': max(self.entity2id.values()) + 1,
'n_word': max(self.word2id.values()) + 1,
}
Notes:
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
"""
def __init__(self, opt, tokenize, restore=False, save=False):
"""Specify tokenized resource and init base dataset.
Args:
opt (Config or dict): config for dataset or the whole system.
tokenize (str): how to tokenize dataset.
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
save (bool): whether to save dataset after processing. Defaults to False.
"""
resource = resources[tokenize]
dpath = os.path.join(DATASET_PATH, "hredial", tokenize)
super().__init__(opt, dpath, resource, restore, save)
def _load_data(self):
train_data, valid_data, test_data = self._load_raw_data()
self._load_vocab()
self._load_other_data()
vocab = {
'tok2ind': self.tok2ind,
'ind2tok': self.ind2tok,
'entity2id': self.entity2id,
'id2entity': self.id2entity,
'vocab_size': len(self.tok2ind),
'n_entity': self.n_entity
}
return train_data, valid_data, test_data, vocab
def _load_raw_data(self):
# load train/valid/test data
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
train_data = json.load(f)
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
valid_data = json.load(f)
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
test_data = json.load(f)
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
return train_data, valid_data, test_data
def _load_vocab(self):
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
def _load_other_data(self):
# edge extension data
self.conv2items = json.load(open(os.path.join(self.dpath, 'conv2items.json'), 'r', encoding='utf-8'))
# dbpedia
self.entity2id = json.load(
open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id}
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
self.n_entity = max(self.entity2id.values()) + 1
self.side_data = pkl.load(open(os.path.join(self.dpath, 'side_data.pkl'), 'rb'))
logger.debug(
f"[Load entity dictionary from {os.path.join(self.dpath, 'entity2id.json')} and side data from {os.path.join(self.dpath, 'side_data.pkl')}]")
def _data_preprocess(self, train_data, valid_data, test_data):
processed_train_data = self._raw_data_process(train_data)
logger.debug("[Finish train data process]")
processed_valid_data = self._raw_data_process(valid_data)
logger.debug("[Finish valid data process]")
processed_test_data = self._raw_data_process(test_data)
logger.debug("[Finish test data process]")
processed_side_data = self.side_data
logger.debug("[Finish side data process]")
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
def _raw_data_process(self, raw_data):
augmented_convs = [self._convert_to_id(conv) for convs in tqdm(raw_data) for conv in convs]
augmented_conv_dicts = []
for conv in tqdm(augmented_convs):
augmented_conv_dicts.extend(self._augment_and_add(conv))
return augmented_conv_dicts
# Convert text, movie and entity mentions into ids
def _convert_to_id(self, conversation):
augmented_convs = []
last_role = None
conv_id = conversation["conv_id"]
related_item = []
related_entity = []
related_word = []
for utt in conversation['dialog']:
self.unk_token_idx = 3
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
item_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
word_ids = [self.tok2ind[word] for word in utt['text'] if word in self.tok2ind]
related_item += item_ids
related_entity += entity_ids
related_word += word_ids
if utt["role"] == last_role:
augmented_convs[-1]["text"] += text_token_ids
augmented_convs[-1]["item"] += item_ids
augmented_convs[-1]["entity"] += entity_ids
augmented_convs[-1]["word"] += word_ids
else:
augmented_convs.append({
"conv_id": conv_id,
"role": utt["role"],
"text": text_token_ids,
"item": related_item,
"entity": related_entity,
"word": related_word
})
last_role = utt["role"]
return augmented_convs
def _augment_and_add(self, raw_conv_dict):
augmented_conv_dicts = []
context_tokens, context_entities, context_words, context_items = [], [], [], []
entity_set, word_set = set(), set()
for i, conv in enumerate(raw_conv_dict):
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"]
if len(context_tokens) > 0:
conv_dict = {
"conv_id": conv["conv_id"],
"role": conv["role"],
"tokens": context_tokens,
"response": text_tokens,
"item": context_items,
"entity": context_entities,
"word": context_words,
"items": movies
}
augmented_conv_dicts.append(conv_dict)
context_tokens.append(text_tokens)
context_items += movies
for entity in entities + movies:
if entity not in entity_set:
entity_set.add(entity)
context_entities.append(entity)
for word in words:
if word not in word_set:
word_set.add(word)
context_words.append(word)
return augmented_conv_dicts


@@ -0,0 +1,66 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/1
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from crslab.download import DownloadableFile
resources = {
'nltk': {
'version': '0.31',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1',
'redial_nltk.zip',
'01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424',
),
'special_token_idx': {
'pad': 0,
'start': 1,
'end': 2,
'unk': 3,
'pad_entity': 0,
'pad_word': 0
},
},
'bert': {
'version': '0.31',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1',
'redial_bert.zip',
'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd',
),
'special_token_idx': {
'pad': 0,
'start': 101,
'end': 102,
'unk': 100,
'sent_split': 2,
'word_split': 3,
'pad_entity': 0,
'pad_word': 0,
},
},
'gpt2': {
'version': '0.31',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1',
'redial_gpt2.zip',
'37b1a64032241903a37b5e014ee36e50d09f7e4a849058688e9af52027a3ac36',
),
'special_token_idx': {
'pad': 0,
'start': 1,
'end': 2,
'unk': 3,
'sent_split': 4,
'word_split': 5,
'pad_entity': 0,
'pad_word': 0
},
}
}


@@ -0,0 +1 @@
from .htgredial import HTGReDialDataset


@@ -0,0 +1,210 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/23, 2021/1/3, 2020/12/19
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail
r"""
ReDial
======
References:
Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS 2018.
.. _`"Towards deep conversational recommendations."`:
https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html
"""
import json
import os
import pickle as pkl
from copy import copy
from loguru import logger
from tqdm import tqdm
from crslab.config import DATASET_PATH
from crslab.data.dataset.base import BaseDataset
from .resources import resources
class HTGReDialDataset(BaseDataset):
"""
Attributes:
train_data: train dataset.
valid_data: valid dataset.
test_data: test dataset.
vocab (dict): ::
{
'tok2ind': map from token to index,
'ind2tok': map from index to token,
'entity2id': map from entity to index,
'id2entity': map from index to entity,
'word2id': map from word to index,
'vocab_size': len(self.tok2ind),
'n_entity': max(self.entity2id.values()) + 1,
'n_word': max(self.word2id.values()) + 1,
}
Notes:
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
"""
def __init__(self, opt, tokenize, restore=False, save=False):
"""Specify tokenized resource and init base dataset.
Args:
opt (Config or dict): config for dataset or the whole system.
tokenize (str): how to tokenize dataset.
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
save (bool): whether to save dataset after processing. Defaults to False.
"""
resource = resources[tokenize]
dpath = os.path.join(DATASET_PATH, "htgredial", tokenize)
super().__init__(opt, dpath, resource, restore, save)
def _load_data(self):
train_data, valid_data, test_data = self._load_raw_data()
self._load_vocab()
self._load_other_data()
vocab = {
'tok2ind': self.tok2ind,
'ind2tok': self.ind2tok,
'entity2id': self.entity2id,
'id2entity': self.id2entity,
'vocab_size': len(self.tok2ind),
'n_entity': self.n_entity
}
return train_data, valid_data, test_data, vocab
def _load_raw_data(self):
# load train/valid/test data
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
train_data = json.load(f)
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
valid_data = json.load(f)
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
test_data = json.load(f)
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
return train_data, valid_data, test_data
def _load_vocab(self):
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
def _load_other_data(self):
# edge extension data
self.conv2items = json.load(open(os.path.join(self.dpath, 'conv2items.json'), 'r', encoding='utf-8'))
# hypergraph
self.user2hypergraph = json.load(open(os.path.join(self.dpath, 'user2hypergraph.json'), 'r', encoding='utf-8'))
# dbpedia
self.entity2id = json.load(
open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id}
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
self.n_entity = max(self.entity2id.values()) + 1
self.side_data = pkl.load(open(os.path.join(self.dpath, 'side_data.pkl'), 'rb'))
logger.debug(
f"[Load entity dictionary from {os.path.join(self.dpath, 'entity2id.json')} and side data from {os.path.join(self.dpath, 'side_data.pkl')}]")
def _data_preprocess(self, train_data, valid_data, test_data):
processed_train_data = self._raw_data_process(train_data)
logger.debug("[Finish train data process]")
processed_valid_data = self._raw_data_process(valid_data)
logger.debug("[Finish valid data process]")
processed_test_data = self._raw_data_process(test_data)
logger.debug("[Finish test data process]")
processed_side_data = self.side_data
logger.debug("[Finish side data process]")
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
def _raw_data_process(self, raw_data):
augmented_convs = [self._convert_to_id(conv) for convs in tqdm(raw_data) for conv in convs]
augmented_conv_dicts = []
for conv in tqdm(augmented_convs):
augmented_conv_dicts.extend(self._augment_and_add(conv))
return augmented_conv_dicts
# Convert text, movie and entity mentions into ids
def _convert_to_id(self, conversation):
augmented_convs = []
last_role = None
conv_id = conversation["conv_id"]
related_item = []
related_entity = []
related_word = []
for utt in conversation['dialog']:
self.unk_token_idx = 3
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
item_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
word_ids = [self.tok2ind[word] for word in utt['text'] if word in self.tok2ind]
related_item += item_ids
related_entity += entity_ids
related_word += word_ids
if utt["role"] == last_role:
augmented_convs[-1]["text"] += text_token_ids
augmented_convs[-1]["item"] += item_ids
augmented_convs[-1]["entity"] += entity_ids
augmented_convs[-1]["word"] += word_ids
else:
augmented_convs.append({
"conv_id": conv_id,
"role": utt["role"],
"text": text_token_ids,
"item": related_item,
"entity": related_entity,
"word": related_word
})
last_role = utt["role"]
return augmented_convs
def _augment_and_add(self, raw_conv_dict):
augmented_conv_dicts = []
context_tokens, context_entities, context_words, context_items = [], [], [], []
entity_set, word_set = set(), set()
for i, conv in enumerate(raw_conv_dict):
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"]
if len(context_tokens) > 0:
conv_dict = {
"conv_id": conv["conv_id"],
"role": conv["role"],
"tokens": copy(context_tokens),
"response": text_tokens,
"item": copy(context_items),
"entity": copy(context_entities),
"word": copy(context_words),
"items": movies
}
augmented_conv_dicts.append(conv_dict)
context_tokens.append(text_tokens)
context_items += movies
for entity in entities + movies:
if entity not in entity_set:
entity_set.add(entity)
context_entities.append(entity)
for word in words:
if word not in word_set:
word_set.add(word)
context_words.append(word)
return augmented_conv_dicts


@@ -0,0 +1,66 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/1
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from crslab.download import DownloadableFile
resources = {
'pkuseg': {
'version': '0.31',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1',
'redial_nltk.zip',
'01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424',
),
'special_token_idx': {
'pad': 0,
'start': 1,
'end': 2,
'unk': 3,
'pad_entity': 0,
'pad_word': 0
},
},
'bert': {
'version': '0.31',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1',
'redial_bert.zip',
'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd',
),
'special_token_idx': {
'pad': 0,
'start': 101,
'end': 102,
'unk': 100,
'sent_split': 2,
'word_split': 3,
'pad_entity': 0,
'pad_word': 0,
},
},
'gpt2': {
'version': '0.31',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1',
'redial_gpt2.zip',
'37b1a64032241903a37b5e014ee36e50d09f7e4a849058688e9af52027a3ac36',
),
'special_token_idx': {
'pad': 0,
'start': 1,
'end': 2,
'unk': 3,
'sent_split': 4,
'word_split': 5,
'pad_entity': 0,
'pad_word': 0
},
}
}


@@ -0,0 +1 @@
from .opendialkg import OpenDialKGDataset


@@ -0,0 +1,286 @@
# @Time : 2020/12/19
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/12/20, 2021/1/2
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
r"""
OpenDialKG
==========
References:
Moon, Seungwhan, et al. `"Opendialkg: Explainable conversational reasoning with attention-based walks over knowledge graphs."`_ in ACL 2019.
.. _`"Opendialkg: Explainable conversational reasoning with attention-based walks over knowledge graphs."`:
https://www.aclweb.org/anthology/P19-1081/
"""
import json
import os
from collections import defaultdict
from copy import copy
from loguru import logger
from tqdm import tqdm
from crslab.config import DATASET_PATH
from crslab.data.dataset.base import BaseDataset
from .resources import resources
class OpenDialKGDataset(BaseDataset):
"""
Attributes:
train_data: train dataset.
valid_data: valid dataset.
test_data: test dataset.
vocab (dict): ::
{
'tok2ind': map from token to index,
'ind2tok': map from index to token,
'entity2id': map from entity to index,
'id2entity': map from index to entity,
'word2id': map from word to index,
'vocab_size': len(self.tok2ind),
'n_entity': max(self.entity2id.values()) + 1,
'n_word': max(self.word2id.values()) + 1,
}
Notes:
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
"""
def __init__(self, opt, tokenize, restore=False, save=False):
"""Specify tokenized resource and init base dataset.
Args:
opt (Config or dict): config for dataset or the whole system.
tokenize (str): how to tokenize dataset.
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
save (bool): whether to save dataset after processing. Defaults to False.
"""
resource = resources[tokenize]
self.special_token_idx = resource['special_token_idx']
self.unk_token_idx = self.special_token_idx['unk']
dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize)
self.train_review = None
self.valid_review = None
self.test_review = None
super().__init__(opt, dpath, resource, restore, save)
def _load_data(self):
train_data, valid_data, test_data = self._load_raw_data()
self._load_vocab()
self._load_other_data()
vocab = {
'tok2ind': self.tok2ind,
'ind2tok': self.ind2tok,
'entity2id': self.entity2id,
'id2entity': self.id2entity,
'word2id': self.word2id,
'vocab_size': len(self.tok2ind),
'n_entity': self.n_entity,
'n_word': self.n_word,
}
vocab.update(self.special_token_idx)
return train_data, valid_data, test_data, vocab
def _load_raw_data(self):
# load train/valid/test data
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
train_data = json.load(f)
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
valid_data = json.load(f)
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
test_data = json.load(f)
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
return train_data, valid_data, test_data
def _load_vocab(self):
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
def _load_other_data(self):
# opendialkg
self.entity2id = json.load(
open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id}
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
self.n_entity = max(self.entity2id.values()) + 1
# {head_entity_id: [(relation_id, tail_entity_id)]}
self.entity_kg = open(os.path.join(self.dpath, 'opendialkg_subkg.txt'), encoding='utf-8')
logger.debug(
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'opendialkg_subkg.txt')}]")
# conceptnet
# {concept: concept_id}
self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8'))
self.n_word = max(self.word2id.values()) + 1
# {concept \t relation\t concept}
self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8')
logger.debug(
f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]")
def _data_preprocess(self, train_data, valid_data, test_data):
processed_train_data = self._raw_data_process(train_data)
logger.debug("[Finish train data process]")
processed_valid_data = self._raw_data_process(valid_data)
logger.debug("[Finish valid data process]")
processed_test_data = self._raw_data_process(test_data)
logger.debug("[Finish test data process]")
processed_side_data = self._side_data_process()
logger.debug("[Finish side data process]")
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
def _raw_data_process(self, raw_data):
augmented_convs = [self._convert_to_id(idx, conversation) for idx, conversation in enumerate(tqdm(raw_data))]
augmented_conv_dicts = []
for conv in tqdm(augmented_convs):
augmented_conv_dicts.extend(self._augment_and_add(conv))
return augmented_conv_dicts
def _convert_to_id(self, idx, conversation):
augmented_convs = []
last_role = None
conv_id = conversation.get("conv_id", idx)
related_item = []
related_entity = []
related_word = []
for utt in conversation['dialog']:
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id]
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]
related_item += item_ids
related_entity += entity_ids
related_word += word_ids
if utt["role"] == last_role:
augmented_convs[-1]["text"] += text_token_ids
augmented_convs[-1]["item"] += item_ids
augmented_convs[-1]["entity"] += entity_ids
augmented_convs[-1]["word"] += word_ids
else:
augmented_convs.append({
"conv_id": conv_id,
"role": utt["role"],
"text": text_token_ids,
"item": related_item,
"entity": related_entity,
"word": related_word
})
last_role = utt["role"]
return augmented_convs
def _augment_and_add(self, raw_conv_dict):
augmented_conv_dicts = []
context_tokens, context_entities, context_words, context_items = [], [], [], []
entity_set, word_set = set(), set()
for i, conv in enumerate(raw_conv_dict):
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"]
if len(context_tokens) > 0:
conv_dict = {
"conv_id": conv["conv_id"],
"role": conv["role"],
"tokens": copy(context_tokens),
"response": text_tokens,
"item": copy(context_items),
"entity": copy(context_entities),
"word": copy(context_words),
"items": movies
}
augmented_conv_dicts.append(conv_dict)
context_tokens.append(text_tokens)
context_items += movies
for entity in entities + movies:
if entity not in entity_set:
entity_set.add(entity)
context_entities.append(entity)
for word in words:
if word not in word_set:
word_set.add(word)
context_words.append(word)
return augmented_conv_dicts
def _side_data_process(self):
processed_entity_kg = self._entity_kg_process()
logger.debug("[Finish entity KG process]")
processed_word_kg = self._word_kg_process()
logger.debug("[Finish word KG process]")
item_entity_ids = json.load(open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8'))
logger.debug('[Load item entity ids]')
side_data = {
"entity_kg": processed_entity_kg,
"word_kg": processed_word_kg,
"item_entity_ids": item_entity_ids,
}
return side_data
def _entity_kg_process(self):
edge_list = [] # [(entity, entity, relation)]
for line in self.entity_kg:
triple = line.strip().split('\t')
if len(triple) != 3 or triple[0] not in self.entity2id or triple[2] not in self.entity2id:
continue
e0 = self.entity2id[triple[0]]
e1 = self.entity2id[triple[2]]
r = triple[1]
edge_list.append((e0, e1, r))
# edge_list.append((e1, e0, r))
edge_list.append((e0, e0, 'SELF_LOOP'))
if e1 != e0:
edge_list.append((e1, e1, 'SELF_LOOP'))
relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set()
for h, t, r in edge_list:
relation_cnt[r] += 1
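# keep only frequent relations (more than 20,000 occurrences in the sub-KG);
# rarer relations are dropped so the resulting entity KG stays compact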
for h, t, r in edge_list:
if relation_cnt[r] > 20000:
if r not in relation2id:
relation2id[r] = len(relation2id)
edges.add((h, t, relation2id[r]))
entities.add(self.id2entity[h])
entities.add(self.id2entity[t])
return {
'edge': list(edges),
'n_relation': len(relation2id),
'entity': list(entities)
}
def _word_kg_process(self):
edges = set() # {(entity, entity)}
entities = set()
for line in self.word_kg:
triple = line.strip().split('\t')
entities.add(triple[0])
entities.add(triple[2])
e0 = self.word2id[triple[0]]
e1 = self.word2id[triple[2]]
edges.add((e0, e1))
edges.add((e1, e0))
# edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
return {
'edge': list(edges),
'entity': list(entities)
}

View File

@@ -0,0 +1,66 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/21
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from crslab.download import DownloadableFile
resources = {
'nltk': {
'version': '0.3',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1',
'opendialkg_nltk.zip',
'6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f'
),
'special_token_idx': {
'pad': 0,
'start': 1,
'end': 2,
'unk': 3,
'pad_entity': 0,
'pad_word': 0,
},
},
'bert': {
'version': '0.3',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1',
'opendialkg_bert.zip',
'0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07'
),
'special_token_idx': {
'pad': 0,
'start': 101,
'end': 102,
'unk': 100,
'sent_split': 2,
'word_split': 3,
'pad_entity': 0,
'pad_word': 0,
},
},
'gpt2': {
'version': '0.3',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1',
'opendialkg_gpt2.zip',
'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e'
),
'special_token_idx': {
'pad': 0,
'start': 1,
'end': 2,
'unk': 3,
'sent_split': 4,
'word_split': 5,
'pad_entity': 0,
'pad_word': 0
},
}
}

View File

@@ -0,0 +1 @@
from .redial import ReDialDataset

View File

@@ -0,0 +1,269 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/23, 2021/1/3, 2020/12/19
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail
r"""
ReDial
======
References:
Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS 2018.
.. _`"Towards deep conversational recommendations."`:
https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html
"""
import json
import os
from collections import defaultdict
from copy import copy
from loguru import logger
from tqdm import tqdm
from crslab.config import DATASET_PATH
from crslab.data.dataset.base import BaseDataset
from .resources import resources
class ReDialDataset(BaseDataset):
"""
Attributes:
train_data: train dataset.
valid_data: valid dataset.
test_data: test dataset.
vocab (dict): ::
{
'tok2ind': map from token to index,
'ind2tok': map from index to token,
'entity2id': map from entity to index,
'id2entity': map from index to entity,
'word2id': map from word to index,
'vocab_size': len(self.tok2ind),
'n_entity': max(self.entity2id.values()) + 1,
'n_word': max(self.word2id.values()) + 1,
}
Notes:
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
"""
def __init__(self, opt, tokenize, restore=False, save=False):
"""Specify tokenized resource and init base dataset.
Args:
opt (Config or dict): config for dataset or the whole system.
tokenize (str): how to tokenize dataset.
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
save (bool): whether to save dataset after processing. Defaults to False.
"""
resource = resources[tokenize]
self.special_token_idx = resource['special_token_idx']
self.unk_token_idx = self.special_token_idx['unk']
dpath = os.path.join(DATASET_PATH, "redial", tokenize)
super().__init__(opt, dpath, resource, restore, save)
def _load_data(self):
train_data, valid_data, test_data = self._load_raw_data()
self._load_vocab()
self._load_other_data()
vocab = {
'tok2ind': self.tok2ind,
'ind2tok': self.ind2tok,
'entity2id': self.entity2id,
'id2entity': self.id2entity,
'word2id': self.word2id,
'vocab_size': len(self.tok2ind),
'n_entity': self.n_entity,
'n_word': self.n_word,
}
vocab.update(self.special_token_idx)
return train_data, valid_data, test_data, vocab
def _load_raw_data(self):
# load train/valid/test data
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
train_data = json.load(f)
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
valid_data = json.load(f)
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
test_data = json.load(f)
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
return train_data, valid_data, test_data
def _load_vocab(self):
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
def _load_other_data(self):
# dbpedia
self.entity2id = json.load(
open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id}
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
self.n_entity = max(self.entity2id.values()) + 1
# {head_entity_id: [(relation_id, tail_entity_id)]}
self.entity_kg = json.load(open(os.path.join(self.dpath, 'dbpedia_subkg.json'), 'r', encoding='utf-8'))
logger.debug(
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]")
# conceptNet
# {concept: concept_id}
self.word2id = json.load(open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8'))
self.n_word = max(self.word2id.values()) + 1
# {relation\t concept \t concept}
self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8')
logger.debug(
f"[Load word dictionary and KG from {os.path.join(self.dpath, 'concept2id.json')} and {os.path.join(self.dpath, 'conceptnet_subkg.txt')}]")
def _data_preprocess(self, train_data, valid_data, test_data):
processed_train_data = self._raw_data_process(train_data)
logger.debug("[Finish train data process]")
processed_valid_data = self._raw_data_process(valid_data)
logger.debug("[Finish valid data process]")
processed_test_data = self._raw_data_process(test_data)
logger.debug("[Finish test data process]")
processed_side_data = self._side_data_process()
logger.debug("[Finish side data process]")
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
def _raw_data_process(self, raw_data):
augmented_convs = [self._merge_conv_data(conversation["dialog"]) for conversation in tqdm(raw_data)]
augmented_conv_dicts = []
for conv in tqdm(augmented_convs):
augmented_conv_dicts.extend(self._augment_and_add(conv))
return augmented_conv_dicts
def _merge_conv_data(self, dialog):
augmented_convs = []
last_role = None
for utt in dialog:
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]
if utt["role"] == last_role:
augmented_convs[-1]["text"] += text_token_ids
augmented_convs[-1]["movie"] += movie_ids
augmented_convs[-1]["entity"] += entity_ids
augmented_convs[-1]["word"] += word_ids
else:
augmented_convs.append({
"role": utt["role"],
"text": text_token_ids,
"entity": entity_ids,
"movie": movie_ids,
"word": word_ids
})
last_role = utt["role"]
return augmented_convs
def _augment_and_add(self, raw_conv_dict):
augmented_conv_dicts = []
context_tokens, context_entities, context_words, context_items = [], [], [], []
entity_set, word_set = set(), set()
for i, conv in enumerate(raw_conv_dict):
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["movie"], conv["word"]
if len(context_tokens) > 0:
conv_dict = {
"conv_id": i,
"role": conv['role'],
"tokens": copy(context_tokens),
"response": text_tokens,
"entity": copy(context_entities),
"word": copy(context_words),
"item": copy(context_items),
"items": movies,
}
augmented_conv_dicts.append(conv_dict)
context_tokens.append(text_tokens)
context_items += movies
for entity in entities + movies:
if entity not in entity_set:
entity_set.add(entity)
context_entities.append(entity)
for word in words:
if word not in word_set:
word_set.add(word)
context_words.append(word)
return augmented_conv_dicts
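# Each sample produced above pairs the running context with the current turn, roughly
# (ids are illustrative, not taken from the real vocabulary):
#   {'conv_id': 3, 'role': 'Recommender',
#    'tokens': [...previous utterances as token-id lists...], 'response': [...token ids of this turn...],
#    'entity': [...deduplicated context entity ids...], 'word': [...deduplicated context word ids...],
#    'item': [...all movies mentioned so far...], 'items': [...movies mentioned in this turn (the labels)...]}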
def _side_data_process(self):
processed_entity_kg = self._entity_kg_process()
logger.debug("[Finish entity KG process]")
processed_word_kg = self._word_kg_process()
logger.debug("[Finish word KG process]")
movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8'))
logger.debug('[Load movie entity ids]')
side_data = {
"entity_kg": processed_entity_kg,
"word_kg": processed_word_kg,
"item_entity_ids": movie_entity_ids,
}
return side_data
def _entity_kg_process(self, SELF_LOOP_ID=185):
edge_list = [] # [(entity, entity, relation)]
for entity in range(self.n_entity):
if str(entity) not in self.entity_kg:
continue
edge_list.append((entity, entity, SELF_LOOP_ID)) # add self loop
for tail_and_relation in self.entity_kg[str(entity)]:
if entity != tail_and_relation[1] and tail_and_relation[0] != SELF_LOOP_ID:
edge_list.append((entity, tail_and_relation[1], tail_and_relation[0]))
edge_list.append((tail_and_relation[1], entity, tail_and_relation[0]))
relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set()
for h, t, r in edge_list:
relation_cnt[r] += 1
for h, t, r in edge_list:
if relation_cnt[r] > 1000:
if r not in relation2id:
relation2id[r] = len(relation2id)
edges.add((h, t, relation2id[r]))
entities.add(self.id2entity[h])
entities.add(self.id2entity[t])
return {
'edge': list(edges),
'n_relation': len(relation2id),
'entity': list(entities)
}
def _word_kg_process(self):
edges = set() # {(entity, entity)}
entities = set()
for line in self.word_kg:
kg = line.strip().split('\t')
entities.add(kg[1].split('/')[0])
entities.add(kg[2].split('/')[0])
e0 = self.word2id[kg[1].split('/')[0]]
e1 = self.word2id[kg[2].split('/')[0]]
edges.add((e0, e1))
edges.add((e1, e0))
# edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
return {
'edge': list(edges),
'entity': list(entities)
}

View File

@@ -0,0 +1,66 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/1
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from crslab.download import DownloadableFile
resources = {
'nltk': {
'version': '0.31',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1',
'redial_nltk.zip',
'01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424',
),
'special_token_idx': {
'pad': 0,
'start': 1,
'end': 2,
'unk': 3,
'pad_entity': 0,
'pad_word': 0
},
},
'bert': {
'version': '0.31',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1',
'redial_bert.zip',
'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd',
),
'special_token_idx': {
'pad': 0,
'start': 101,
'end': 102,
'unk': 100,
'sent_split': 2,
'word_split': 3,
'pad_entity': 0,
'pad_word': 0,
},
},
'gpt2': {
'version': '0.31',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1',
'redial_gpt2.zip',
'15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8',
),
'special_token_idx': {
'pad': -100,
'start': 1,
'end': 2,
'unk': 3,
'sent_split': 4,
'word_split': 5,
'pad_entity': 0,
'pad_word': 0
},
}
}

View File

@@ -0,0 +1 @@
from .tgredial import TGReDialDataset

View File

@@ -0,0 +1,71 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/4
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from crslab.download import DownloadableFile
resources = {
'pkuseg': {
'version': '0.3',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1',
'tgredial_pkuseg.zip',
'8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6',
),
'special_token_idx': {
'pad': 0,
'start': 1,
'end': 2,
'unk': 3,
'pad_entity': 0,
'pad_word': 0,
'pad_topic': 0
},
},
'bert': {
'version': '0.3',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1',
'tgredial_bert.zip',
'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b'
),
'special_token_idx': {
'pad': 0,
'start': 101,
'end': 102,
'unk': 100,
'sent_split': 2,
'word_split': 3,
'pad_entity': 0,
'pad_word': 0,
'pad_topic': 0
},
},
'gpt2': {
'version': '0.3',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1',
'tgredial_gpt2.zip',
'2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9'
),
'special_token_idx': {
'pad': 0,
'start': 101,
'end': 102,
'unk': 100,
'cls': 101,
'sep': 102,
'sent_split': 2,
'word_split': 3,
'pad_entity': 0,
'pad_word': 0,
'pad_topic': 0,
},
}
}

View File

@@ -0,0 +1,343 @@
# @Time : 2020/12/4
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/12/6, 2021/1/2, 2020/12/19
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
# @Email : francis_kun_zhou@163.com, sdzyh002@gmail
r"""
TGReDial
========
References:
Zhou, Kun, et al. `"Towards Topic-Guided Conversational Recommender System."`_ in COLING 2020.
.. _`"Towards Topic-Guided Conversational Recommender System."`:
https://www.aclweb.org/anthology/2020.coling-main.365/
"""
import json
import os
from collections import defaultdict
from copy import copy
import numpy as np
from loguru import logger
from tqdm import tqdm
from crslab.config import DATASET_PATH
from crslab.data.dataset.base import BaseDataset
from .resources import resources
class TGReDialDataset(BaseDataset):
"""
Attributes:
train_data: train dataset.
valid_data: valid dataset.
test_data: test dataset.
vocab (dict): ::
{
'tok2ind': map from token to index,
'ind2tok': map from index to token,
'topic2ind': map from topic to index,
'ind2topic': map from index to topic,
'entity2id': map from entity to index,
'id2entity': map from index to entity,
'word2id': map from word to index,
'vocab_size': len(self.tok2ind),
'n_topic': len(self.topic2ind) + 1,
'n_entity': max(self.entity2id.values()) + 1,
'n_word': max(self.word2id.values()) + 1,
}
Notes:
``'unk'`` and ``'pad_topic'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
"""
def __init__(self, opt, tokenize, restore=False, save=False):
"""Specify tokenized resource and init base dataset.
Args:
opt (Config or dict): config for dataset or the whole system.
tokenize (str): how to tokenize dataset.
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
save (bool): whether to save dataset after processing. Defaults to False.
"""
resource = resources[tokenize]
self.special_token_idx = resource['special_token_idx']
self.unk_token_idx = self.special_token_idx['unk']
self.pad_topic_idx = self.special_token_idx['pad_topic']
dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize)
self.replace_token = opt.get('replace_token',None)
self.replace_token_idx = opt.get('replace_token_idx',None)
super().__init__(opt, dpath, resource, restore, save)
if self.replace_token:
if self.replace_token_idx:
self.side_data["embedding"][self.replace_token_idx] = self.side_data['embedding'][0]
else:
self.side_data["embedding"] = np.insert(self.side_data["embedding"],len(self.side_data["embedding"]),self.side_data['embedding'][0],axis=0)
def _load_data(self):
train_data, valid_data, test_data = self._load_raw_data()
self._load_vocab()
self._load_other_data()
vocab = {
'tok2ind': self.tok2ind,
'ind2tok': self.ind2tok,
'topic2ind': self.topic2ind,
'ind2topic': self.ind2topic,
'entity2id': self.entity2id,
'id2entity': self.id2entity,
'word2id': self.word2id,
'vocab_size': len(self.tok2ind),
'n_topic': len(self.topic2ind) + 1,
'n_entity': self.n_entity,
'n_word': self.n_word,
}
vocab.update(self.special_token_idx)
return train_data, valid_data, test_data, vocab
def _load_raw_data(self):
# load train/valid/test data
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
train_data = json.load(f)
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
valid_data = json.load(f)
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
test_data = json.load(f)
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
return train_data, valid_data, test_data
def _load_vocab(self):
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
# add special tokens
if self.replace_token:
if self.replace_token not in self.tok2ind:
if self.replace_token_idx:
self.ind2tok[self.replace_token_idx] = self.replace_token
self.tok2ind[self.replace_token] = self.replace_token_idx
self.special_token_idx[self.replace_token] = self.replace_token_idx
else:
self.ind2tok[len(self.tok2ind)] = self.replace_token
self.tok2ind[self.replace_token] = len(self.tok2ind)
self.special_token_idx[self.replace_token] = len(self.tok2ind)-1
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
self.topic2ind = json.load(open(os.path.join(self.dpath, 'topic2id.json'), 'r', encoding='utf-8'))
self.ind2topic = {idx: word for word, idx in self.topic2ind.items()}
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'topic2id.json')}]")
logger.debug(f"[The size of token2index dictionary is {len(self.topic2ind)}]")
logger.debug(f"[The size of index2token dictionary is {len(self.ind2topic)}]")
def _load_other_data(self):
# cn-dbpedia
self.entity2id = json.load(
open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id}
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
self.n_entity = max(self.entity2id.values()) + 1
# {head_entity_id: [(relation_id, tail_entity_id)]}
self.entity_kg = open(os.path.join(self.dpath, 'cn-dbpedia.txt'), encoding='utf-8')
logger.debug(
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'cn-dbpedia.txt')}]")
# hownet
# {concept: concept_id}
self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8'))
self.n_word = max(self.word2id.values()) + 1
# {relation\t concept \t concept}
self.word_kg = open(os.path.join(self.dpath, 'hownet.txt'), encoding='utf-8')
logger.debug(
f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet.txt')}]")
# user interaction history dictionary
self.conv2history = json.load(open(os.path.join(self.dpath, 'user2history.json'), 'r', encoding='utf-8'))
logger.debug(f"[Load user interaction history from {os.path.join(self.dpath, 'user2history.json')}]")
# user profile
self.user2profile = json.load(open(os.path.join(self.dpath, 'user2profile.json'), 'r', encoding='utf-8'))
logger.debug(f"[Load user profile from {os.path.join(self.dpath, 'user2profile.json')}")
def _data_preprocess(self, train_data, valid_data, test_data):
processed_train_data = self._raw_data_process(train_data)
logger.debug("[Finish train data process]")
processed_valid_data = self._raw_data_process(valid_data)
logger.debug("[Finish valid data process]")
processed_test_data = self._raw_data_process(test_data)
logger.debug("[Finish test data process]")
processed_side_data = self._side_data_process()
logger.debug("[Finish side data process]")
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
def _raw_data_process(self, raw_data):
augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]
augmented_conv_dicts = []
for conv in tqdm(augmented_convs):
augmented_conv_dicts.extend(self._augment_and_add(conv))
return augmented_conv_dicts
def _convert_to_id(self, conversation):
augmented_convs = []
last_role = None
for utt in conversation['messages']:
assert utt['role'] != last_role
# change movies into slots
if self.replace_token:
if len(utt['movie']) != 0:
while '《' in utt['text']:
begin = utt['text'].index("《")
end = utt['text'].index("》")
utt['text'] = utt['text'][:begin] + [self.replace_token] + utt['text'][end+1:]
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
movie_ids = [self.entity2id[movie] for movie in utt['movie'] if movie in self.entity2id]
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]
policy = []
for action, kw in zip(utt['target'][1::2], utt['target'][2::2]):
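# skip actions without keywords and the '推荐电影' ("recommend movie") action,
# so that recommending itself is not treated as a policy target here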
if kw is None or action == '推荐电影':
continue
if isinstance(kw, str):
kw = [kw]
kw = [self.topic2ind.get(k, self.pad_topic_idx) for k in kw]
policy.append([action, kw])
final_kws = [self.topic2ind[kw] if kw is not None else self.pad_topic_idx for kw in utt['final'][1]]
final = [utt['final'][0], final_kws]
conv_utt_id = str(conversation['conv_id']) + '/' + str(utt['local_id'])
interaction_history = self.conv2history.get(conv_utt_id, [])
user_profile = self.user2profile[conversation['user_id']]
user_profile = [[self.tok2ind.get(token, self.unk_token_idx) for token in sent] for sent in user_profile]
augmented_convs.append({
"role": utt["role"],
"text": text_token_ids,
"entity": entity_ids,
"movie": movie_ids,
"word": text_token_ids,
'policy': policy,
'final': final,
'interaction_history': interaction_history,
'user_profile': user_profile
})
last_role = utt["role"]
return augmented_convs
def _augment_and_add(self, raw_conv_dict):
augmented_conv_dicts = []
context_tokens, context_entities, context_words, context_policy, context_items = [], [], [], [], []
entity_set, word_set = set(), set()
for i, conv in enumerate(raw_conv_dict):
text_tokens, entities, movies, words, policies = conv["text"], conv["entity"], conv["movie"], conv["word"], \
conv['policy']
if self.replace_token is not None:
if text_tokens.count(30000) != len(movies):  # 30000 is the hard-coded id of the movie-slot (replace) token
continue  # skip turns where the number of slots doesn't match the number of movies
if len(context_tokens) > 0:
conv_dict = {
'conv_id': i,
'role': conv['role'],
'user_profile': conv['user_profile'],
"tokens": copy(context_tokens),
"response": text_tokens,
"entity": copy(context_entities),
"word": copy(context_words),
'interaction_history': conv['interaction_history'],
'item': copy(context_items),
"items": movies,
'context_policy': copy(context_policy),
'target': policies,
'final': conv['final'],
}
augmented_conv_dicts.append(conv_dict)
context_tokens.append(text_tokens)
context_policy.append(policies)
context_items += movies
for entity in entities + movies:
if entity not in entity_set:
entity_set.add(entity)
context_entities.append(entity)
for word in words:
if word not in word_set:
word_set.add(word)
context_words.append(word)
return augmented_conv_dicts
def _side_data_process(self):
processed_entity_kg = self._entity_kg_process()
logger.debug("[Finish entity KG process]")
processed_word_kg = self._word_kg_process()
logger.debug("[Finish word KG process]")
movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8'))
logger.debug('[Load movie entity ids]')
side_data = {
"entity_kg": processed_entity_kg,
"word_kg": processed_word_kg,
"item_entity_ids": movie_entity_ids,
}
return side_data
def _entity_kg_process(self):
edge_list = [] # [(entity, entity, relation)]
for line in self.entity_kg:
triple = line.strip().split('\t')
e0 = self.entity2id[triple[0]]
e1 = self.entity2id[triple[2]]
r = triple[1]
edge_list.append((e0, e1, r))
edge_list.append((e1, e0, r))
edge_list.append((e0, e0, 'SELF_LOOP'))
if e1 != e0:
edge_list.append((e1, e1, 'SELF_LOOP'))
relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set()
for h, t, r in edge_list:
relation_cnt[r] += 1
for h, t, r in edge_list:
if r not in relation2id:
relation2id[r] = len(relation2id)
edges.add((h, t, relation2id[r]))
entities.add(self.id2entity[h])
entities.add(self.id2entity[t])
return {
'edge': list(edges),
'n_relation': len(relation2id),
'entity': list(entities)
}
def _word_kg_process(self):
edges = set() # {(entity, entity)}
entities = set()
for line in self.word_kg:
triple = line.strip().split('\t')
entities.add(triple[0])
entities.add(triple[2])
e0 = self.word2id[triple[0]]
e1 = self.word2id[triple[2]]
edges.add((e0, e1))
edges.add((e1, e0))
# edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
return {
'edge': list(edges),
'entity': list(entities)
}

275
HyCoRec/crslab/download.py Normal file
View File

@@ -0,0 +1,275 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/7
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/7
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
import hashlib
import os
import shutil
import time
import datetime
import requests
import tqdm
from loguru import logger
class DownloadableFile:
"""
A class used to abstract any file that has to be downloaded online.
Any task that needs to download a file needs to have a list RESOURCES
that has objects of this class as elements.
This class provides the following functionality:
- Download a file from a URL
- Untar the file if zipped
- Checksum for the downloaded file
An object of this class needs to be created with:
- url <string> : URL or Google Drive id to download from
- file_name <string> : File name that the file should be named
- hashcode <string> : SHA256 hashcode of the downloaded file
- zipped <boolean> : False if the file is not compressed
- from_google <boolean> : True if the file is from Google Drive
"""
def __init__(self, url, file_name, hashcode, zipped=True, from_google=False):
self.url = url
self.file_name = file_name
self.hashcode = hashcode
self.zipped = zipped
self.from_google = from_google
def checksum(self, dpath):
"""
Checksum on a given file.
:param dpath: path to the downloaded file.
"""
sha256_hash = hashlib.sha256()
with open(os.path.join(dpath, self.file_name), "rb") as f:
for byte_block in iter(lambda: f.read(65536), b""):
sha256_hash.update(byte_block)
if sha256_hash.hexdigest() != self.hashcode:
# remove_dir(dpath)
raise AssertionError(
f"[ Checksum for {self.file_name} from \n{self.url}\n"
"does not match the expected checksum. Please try again. ]"
)
else:
logger.debug("Checksum Successful")
def download_file(self, dpath):
if self.from_google:
download_from_google_drive(self.url, os.path.join(dpath, self.file_name))
else:
download(self.url, dpath, self.file_name)
self.checksum(dpath)
if self.zipped:
untar(dpath, self.file_name)
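# A hedged usage sketch (the URL, file name and checksum below are placeholders,
# not a real resource):
#
#   toy = DownloadableFile(
#       url='https://example.com/toy_dataset.zip',    # hypothetical URL
#       file_name='toy_dataset.zip',
#       hashcode='<sha256 of the archive>',           # hypothetical checksum
#       zipped=True,
#   )
#   toy.download_file('data/toy')  # download, verify the checksum, then unpack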
def download(url, path, fname, redownload=False, num_retries=5):
"""
Download file using `requests`.
If ``redownload`` is set to False, the file will not be downloaded again if it is
already present (default ``False``).
"""
outfile = os.path.join(path, fname)
download = not os.path.exists(outfile) or redownload
logger.info(f"Downloading {url} to {outfile}")
retry = num_retries
exp_backoff = [2 ** r for r in reversed(range(retry))]
pbar = tqdm.tqdm(unit='B', unit_scale=True, desc='Downloading {}'.format(fname))
while download and retry > 0:
response = None
try:
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.60',
}
response = requests.get(url, stream=True, headers=headers)
# negative reply could be 'none' or just missing
CHUNK_SIZE = 32768
total_size = int(response.headers.get('Content-Length', -1))
# server returns remaining size if resuming, so adjust total
pbar.total = total_size
done = 0
with open(outfile, 'wb') as f:
for chunk in response.iter_content(CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if total_size > 0:
done += len(chunk)
if total_size < done:
# don't freak out if content-length was too small
total_size = done
pbar.total = total_size
pbar.update(len(chunk))
break
except (
requests.exceptions.ConnectionError,
requests.exceptions.ReadTimeout,
):
retry -= 1
pbar.clear()
if retry > 0:
pl = 'y' if retry == 1 else 'ies'
logger.debug(
f'Connection error, retrying. ({retry} retr{pl} left)'
)
time.sleep(exp_backoff[retry])
else:
logger.error('Retried too many times, stopped retrying.')
finally:
if response:
response.close()
if retry <= 0:
raise RuntimeError('Connection broken too many times. Stopped retrying.')
if download and retry > 0:
pbar.update(done - pbar.n)
if done < total_size:
raise RuntimeError(
f'Received less data than specified in Content-Length header for '
f'{url}. There may be a download problem.'
)
pbar.close()
def _get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
def download_from_google_drive(gd_id, destination):
"""
Use the requests package to download a file from Google Drive.
"""
URL = 'https://docs.google.com/uc?export=download'
with requests.Session() as session:
response = session.get(URL, params={'id': gd_id}, stream=True)
token = _get_confirm_token(response)
if token:
response.close()
params = {'id': gd_id, 'confirm': token}
response = session.get(URL, params=params, stream=True)
CHUNK_SIZE = 32768
with open(destination, 'wb') as f:
for chunk in response.iter_content(CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
response.close()
def move(path1, path2):
"""
Rename the given file.
"""
shutil.move(path1, path2)
def untar(path, fname, deleteTar=True):
"""
Unpack the given archive file to the same directory.
:param str path:
The folder containing the archive. Will contain the contents.
:param str fname:
The filename of the archive file.
:param bool deleteTar:
If true, the archive will be deleted after extraction.
"""
logger.debug(f'unpacking {fname}')
fullpath = os.path.join(path, fname)
shutil.unpack_archive(fullpath, path)
if deleteTar:
os.remove(fullpath)
def make_dir(path):
"""
Make the directory and any nonexistent parent directories (`mkdir -p`).
"""
# the current working directory is a fine path
if path != '':
os.makedirs(path, exist_ok=True)
def remove_dir(path):
"""
Remove the given directory, if it exists.
"""
shutil.rmtree(path, ignore_errors=True)
def check_build(path, version_string=None):
"""
Check if '.built' flag has been set for that task.
If a version_string is provided, this has to match, or the version is regarded as
not built.
"""
if version_string:
fname = os.path.join(path, '.built')
if not os.path.isfile(fname):
return False
else:
with open(fname, 'r') as read:
text = read.read().split('\n')
return len(text) > 1 and text[1] == version_string
else:
return os.path.isfile(os.path.join(path, '.built'))
def mark_done(path, version_string=None):
"""
Mark this path as prebuilt.
Marks the path as done by adding a '.built' file with the current timestamp
plus a version description string if specified.
:param str path:
The file path to mark as built.
:param str version_string:
The version of this dataset.
"""
with open(os.path.join(path, '.built'), 'w') as write:
write.write(str(datetime.datetime.today()))
if version_string:
write.write('\n' + version_string)
def build(dpath, dfile, version=None):
if not check_build(dpath, version):
logger.info('[Building data: ' + dpath + ']')
if check_build(dpath):
remove_dir(dpath)
make_dir(dpath)
# Download the data.
downloadable_file = dfile
downloadable_file.download_file(dpath)
mark_done(dpath, version)
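# A minimal sketch (not part of the original module) of the '.built' version-flag
# protocol implemented by check_build / mark_done / build, exercised on a temporary folder.
if __name__ == '__main__':
    import tempfile
    demo_dir = tempfile.mkdtemp()
    print(check_build(demo_dir, '0.1'))  # False: nothing has been built yet
    mark_done(demo_dir, '0.1')           # writes <demo_dir>/.built with a timestamp and the version
    print(check_build(demo_dir, '0.1'))  # True: the stored version matches
    print(check_build(demo_dir, '0.2'))  # False: a version mismatch forces a rebuild
    remove_dir(demo_dir)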

View File

@@ -0,0 +1,28 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/22
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from loguru import logger
from .standard import StandardEvaluator
from ..data import dataset_language_map
Evaluator_register_table = {
'standard': StandardEvaluator
}
def get_evaluator(evaluator_name, dataset, file_path):
if evaluator_name in Evaluator_register_table:
language = dataset_language_map[dataset]
evaluator = Evaluator_register_table[evaluator_name](language, file_path)
logger.info(f'[Build evaluator {evaluator_name}]')
return evaluator
else:
raise NotImplementedError(f'Evaluator [{evaluator_name}] has not been implemented')

View File

@@ -0,0 +1,31 @@
# @Time : 2020/11/30
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
# UPDATE:
# @Time : 2020/11/30
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
from abc import ABC, abstractmethod
class BaseEvaluator(ABC):
"""Base class for evaluator"""
def rec_evaluate(self, preds, label):
pass
def gen_evaluate(self, preds, label):
pass
def policy_evaluate(self, preds, label):
pass
@abstractmethod
def report(self, epoch, mode):
pass
@abstractmethod
def reset_metrics(self):
pass

View File

@@ -0,0 +1,30 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/18
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/18
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from crslab.download import DownloadableFile
resources = {
'zh': {
'version': '0.2',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1',
'cc.zh.300.zip',
'effd9806809a1db106b5166b817aaafaaf3f005846f730d4c49f88c7a28a0ac3'
)
},
'en': {
'version': '0.2',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1',
'cc.en.300.zip',
'96a06a77da70325997eaa52bfd9acb1359a7c3754cb1c1aed2fc27c04936d53e'
)
}
}

View File

@@ -0,0 +1,4 @@
from .base import Metric, Metrics, aggregate_unnamed_reports, AverageMetric
from .gen import BleuMetric, ExactMatchMetric, F1Metric, DistMetric, EmbeddingAverage, VectorExtrema, \
GreedyMatch
from .rec import RECMetric, NDCGMetric, MRRMetric, CovMetric

View File

@@ -0,0 +1,232 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24, 2020/12/2
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
import functools
from abc import ABC, abstractmethod
import torch
from typing import Any, Union, List, Optional, Dict
TScalar = Union[int, float, torch.Tensor]
TVector = Union[List[TScalar], torch.Tensor]
@functools.total_ordering
class Metric(ABC):
"""
Base class for storing metrics.
Subclasses should define .value(). Examples are provided for each subclass.
"""
@abstractmethod
def value(self):
"""
Return the value of the metric as a float.
"""
pass
@abstractmethod
def __add__(self, other: Any) -> 'Metric':
raise NotImplementedError
def __iadd__(self, other):
return self.__radd__(other)
def __radd__(self, other: Any):
if other is None:
return self
return self.__add__(other)
def __str__(self) -> str:
return f'{self.value():.4g}'
def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.value():.4g})'
def __float__(self) -> float:
return float(self.value())
def __int__(self) -> int:
return int(self.value())
def __eq__(self, other: Any) -> bool:
if isinstance(other, Metric):
return self.value() == other.value()
else:
return self.value() == other
def __lt__(self, other: Any) -> bool:
if isinstance(other, Metric):
return self.value() < other.value()
else:
return self.value() < other
def __sub__(self, other: Any) -> float:
"""
Used heavily for assertAlmostEqual.
"""
if not isinstance(other, float):
raise TypeError('Metrics.__sub__ is intentionally limited to floats.')
return self.value() - other
def __rsub__(self, other: Any) -> float:
"""
Used heavily for assertAlmostEqual.
NOTE: This is not necessary in python 3.7+.
"""
if not isinstance(other, float):
raise TypeError('Metrics.__rsub__ is intentionally limited to floats.')
return other - self.value()
@classmethod
def as_number(cls, obj: TScalar) -> Union[int, float]:
if isinstance(obj, torch.Tensor):
obj_as_number: Union[int, float] = obj.item()
else:
obj_as_number = obj # type: ignore
assert isinstance(obj_as_number, int) or isinstance(obj_as_number, float)
return obj_as_number
@classmethod
def as_float(cls, obj: TScalar) -> float:
return float(cls.as_number(obj))
@classmethod
def as_int(cls, obj: TScalar) -> int:
return int(cls.as_number(obj))
@classmethod
def many(cls, *objs: List[TVector]) -> List['Metric']:
"""
Construct many of a Metric from the base parts.
Useful if you separately compute numerators and denominators, etc.
"""
lengths = [len(o) for o in objs]
if len(set(lengths)) != 1:
raise IndexError(f'Uneven {cls.__name__} constructions: {lengths}')
return [cls(*items) for items in zip(*objs)]
class SumMetric(Metric):
"""
Class that keeps a running sum of some metric.
Examples of SumMetric include things like "exs", the number of examples seen since
the last report, which depends exactly on a teacher.
"""
__slots__ = ('_sum',)
def __init__(self, sum_: TScalar = 0):
if isinstance(sum_, torch.Tensor):
self._sum = sum_.item()
else:
assert isinstance(sum_, (int, float, list))
self._sum = sum_
def __add__(self, other: Optional['SumMetric']) -> 'SumMetric':
# NOTE: hinting can be cleaned up with "from __future__ import annotations" when
# we drop Python 3.6
if other is None:
return self
full_sum = self._sum + other._sum
# always keep the same return type
return type(self)(sum_=full_sum)
def value(self) -> float:
return self._sum
class AverageMetric(Metric):
"""
Class that keeps a running average of some metric.
Examples of AverageMetrics include hits@1, F1, accuracy, etc. These metrics all have
per-example values that can be directly mapped back to a teacher.
"""
__slots__ = ('_numer', '_denom')
def __init__(self, numer: TScalar, denom: TScalar = 1):
self._numer = self.as_number(numer)
self._denom = self.as_number(denom)
def __add__(self, other: Optional['AverageMetric']) -> 'AverageMetric':
# NOTE: hinting can be cleaned up with "from __future__ import annotations" when
# we drop Python 3.6
if other is None:
return self
full_numer: TScalar = self._numer + other._numer
full_denom: TScalar = self._denom + other._denom
# always keep the same return type
return type(self)(numer=full_numer, denom=full_denom)
def value(self) -> float:
if self._numer == 0 and self._denom == 0:
# don't nan out if we haven't counted anything
return 0.0
if self._denom == 0:
return float('nan')
return self._numer / self._denom
def aggregate_unnamed_reports(reports: List[Dict[str, Metric]]) -> Dict[str, Metric]:
"""
Combines metrics without regard for tracking provenance.
"""
m: Dict[str, Metric] = {}
for task_report in reports:
for each_metric, value in task_report.items():
m[each_metric] = m.get(each_metric) + value
return m
class Metrics(object):
"""
Metrics aggregator.
"""
def __init__(self):
self._data = {}
def __str__(self):
return str(self._data)
def __repr__(self):
return f'Metrics({repr(self._data)})'
def get(self, key: str):
if key in self._data.keys():
return self._data[key].value()
else:
raise KeyError(f"Metric '{key}' has not been recorded")
def __getitem__(self, item):
return self.get(item)
def add(self, key: str, value: Optional[Metric]) -> None:
"""
Record an accumulation to a metric.
"""
self._data[key] = self._data.get(key) + value
def report(self):
"""
Report the metrics over all data seen so far.
"""
return {k: (v if "cov" not in k else SumMetric(len(set(v.value())))) for k, v in self._data.items()}
def clear(self):
"""
Clear all the metrics.
"""
self._data.clear()
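# A minimal, self-contained sketch (not part of the original module) of how the metric
# classes above accumulate; it only uses names defined in this file.
if __name__ == '__main__':
    # SumMetric keeps a running total; AverageMetric keeps a numerator/denominator pair.
    exs = SumMetric(1) + SumMetric(1) + SumMetric(1)                        # 3 examples seen
    hit = AverageMetric(1, 1) + AverageMetric(0, 1) + AverageMetric(1, 1)   # 2 hits out of 3
    print(exs.value(), hit.value())  # -> 3 0.666...
    # Metrics aggregates named metrics; the None-safe __radd__ makes the first add work.
    m = Metrics()
    m.add('hit@1', AverageMetric(1, 1))
    m.add('hit@1', AverageMetric(0, 1))
    print(m['hit@1'])  # -> 0.5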

View File

@@ -0,0 +1,158 @@
# @Time : 2020/11/30
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
# UPDATE:
# @Time : 2020/12/18
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
import re
from collections import Counter
import math
import numpy as np
from nltk import ngrams
from nltk.translate.bleu_score import sentence_bleu
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Optional
from crslab.evaluator.metrics.base import AverageMetric, SumMetric
re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
re_space = re.compile(r'\s+')
class PPLMetric(AverageMetric):
def value(self):
return math.exp(super().value())
def normalize_answer(s):
"""
Lower text and remove punctuation, articles and extra whitespace.
"""
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
s = re_space.sub(' ', s)
# s = ' '.join(s.split())
return s
class ExactMatchMetric(AverageMetric):
@staticmethod
def compute(guess: str, answers: List[str]) -> 'ExactMatchMetric':
if guess is None or answers is None:
return None
for a in answers:
if guess == a:
return ExactMatchMetric(1)
return ExactMatchMetric(0)
class F1Metric(AverageMetric):
"""
Helper class which computes token-level F1.
"""
@staticmethod
def _prec_recall_f1_score(pred_items, gold_items):
"""
Compute precision, recall and f1 given a set of gold and prediction items.
:param pred_items: iterable of predicted values
:param gold_items: iterable of gold values
:return: the token-level F1 score (0 if there is no overlap)
"""
common = Counter(gold_items) & Counter(pred_items)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_items)
recall = 1.0 * num_same / len(gold_items)
f1 = (2 * precision * recall) / (precision + recall)
return f1
@staticmethod
def compute(guess: str, answers: List[str]) -> 'F1Metric':
if guess is None or answers is None:
return AverageMetric(0, 0)
g_tokens = guess.split()
scores = [
F1Metric._prec_recall_f1_score(g_tokens, a.split())
for a in answers
]
return F1Metric(max(scores), 1)
class BleuMetric(AverageMetric):
@staticmethod
def compute(guess: str, answers: List[str], k: int) -> Optional['BleuMetric']:
"""
Compute approximate BLEU score between guess and a set of answers.
"""
weights = [0] * 4
weights[k - 1] = 1
score = sentence_bleu(
[a.split(" ") for a in answers],
guess.split(" "),
weights=weights,
)
return BleuMetric(score)
class DistMetric(SumMetric):
@staticmethod
def compute(sent: str, k: int) -> 'DistMetric':
token_set = set()
for token in ngrams(sent.split(), k):
token_set.add(token)
return DistMetric(len(token_set))
class EmbeddingAverage(AverageMetric):
@staticmethod
def _avg_embedding(embedding):
return np.sum(embedding, axis=0) / (np.linalg.norm(np.sum(embedding, axis=0)) + 1e-12)
@staticmethod
def compute(hyp_embedding, ref_embeddings) -> 'EmbeddingAverage':
hyp_avg_emb = EmbeddingAverage._avg_embedding(hyp_embedding).reshape(1, -1)
ref_avg_embs = [EmbeddingAverage._avg_embedding(emb) for emb in ref_embeddings]
ref_avg_embs = np.array(ref_avg_embs)
return EmbeddingAverage(float(cosine_similarity(hyp_avg_emb, ref_avg_embs).max()))
class VectorExtrema(AverageMetric):
@staticmethod
def _extreme_embedding(embedding):
max_emb = np.max(embedding, axis=0)
min_emb = np.min(embedding, axis=0)
extreme_emb = np.fromiter(
map(lambda x, y: x if ((x > y or x < -y) and y > 0) or ((x < y or x > -y) and y < 0) else y, max_emb,
min_emb), dtype=float)
return extreme_emb
@staticmethod
def compute(hyp_embedding, ref_embeddings) -> 'VectorExtrema':
hyp_ext_emb = VectorExtrema._extreme_embedding(hyp_embedding).reshape(1, -1)
ref_ext_embs = [VectorExtrema._extreme_embedding(emb) for emb in ref_embeddings]
ref_ext_embs = np.asarray(ref_ext_embs)
return VectorExtrema(float(cosine_similarity(hyp_ext_emb, ref_ext_embs).max()))
class GreedyMatch(AverageMetric):
@staticmethod
def compute(hyp_embedding, ref_embeddings) -> 'GreedyMatch':
hyp_emb = np.asarray(hyp_embedding)
ref_embs = (np.asarray(ref_embedding) for ref_embedding in ref_embeddings)
score_max = 0
for ref_emb in ref_embs:
sim_mat = cosine_similarity(hyp_emb, ref_emb)
score_max = max(score_max, (sim_mat.max(axis=0).mean() + sim_mat.max(axis=1).mean()) / 2)
return GreedyMatch(score_max)
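# A minimal sketch (not part of the original module): token-level F1, dist-k and BLEU on a
# toy hypothesis/reference pair; the sentences are made up for illustration and the demo
# assumes nltk, numpy and scikit-learn are installed (as imported above).
if __name__ == '__main__':
    hyp = 'i really like sci fi movies'
    refs = ['i like sci fi movies a lot']
    print(F1Metric.compute(hyp, refs).value())       # overlap-based token F1 (about 0.77 here)
    print(DistMetric.compute(hyp, 2).value())        # number of distinct bigrams in the hypothesis
    print(BleuMetric.compute(hyp, refs, 1).value())  # unigram BLEU with brevity penalty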

View File

@@ -0,0 +1,41 @@
# @Time : 2020/11/30
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
# UPDATE:
# @Time : 2020/12/2
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
import math
from crslab.evaluator.metrics.base import AverageMetric, SumMetric
class RECMetric(AverageMetric):
@staticmethod
def compute(ranks, label, k) -> 'RECMetric':
return RECMetric(int(label in ranks[:k]))
class NDCGMetric(AverageMetric):
@staticmethod
def compute(ranks, label, k) -> 'NDCGMetric':
if label in ranks[:k]:
label_rank = ranks.index(label)
return NDCGMetric(1 / math.log2(label_rank + 2))
return NDCGMetric(0)
class MRRMetric(AverageMetric):
@staticmethod
def compute(ranks, label, k) -> 'MRRMetric':
if label in ranks[:k]:
label_rank = ranks.index(label)
return MRRMetric(1 / (label_rank + 1))
return MRRMetric(0)
class CovMetric(SumMetric):
@staticmethod
def compute(ranks, label, k) -> 'CovMetric':
return CovMetric(ranks[:k])
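# A minimal sketch (not part of the original module) of the recommendation metrics on one
# ranked list; the item ids are made up and the demo assumes the crslab package is importable.
if __name__ == '__main__':
    ranks = [42, 7, 13, 99]  # recommended item ids, best first
    label = 13               # the ground-truth item sits at 0-based position 2
    print(RECMetric.compute(ranks, label, 3).value())   # hit within top-3 -> 1.0
    print(NDCGMetric.compute(ranks, label, 3).value())  # 1 / log2(2 + 2)  -> 0.5
    print(MRRMetric.compute(ranks, label, 3).value())   # 1 / (2 + 1)      -> 0.333...
    print(CovMetric.compute(ranks, label, 3).value())   # the top-3 ids kept for coverage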

View File

@@ -0,0 +1,86 @@
# @Time : 2020/11/30
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
# UPDATE:
# @Time : 2020/12/18
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
import os
import json
from collections import defaultdict
from time import perf_counter
from loguru import logger
from nltk import ngrams
from crslab.evaluator.base import BaseEvaluator
from crslab.evaluator.utils import nice_report
from .metrics import *
class StandardEvaluator(BaseEvaluator):
"""The evaluator for all kind of model(recommender, conversation, policy)
Args:
rec_metrics: the metrics to evaluate recommender model, including hit@K, ndcg@K and mrr@K
dist_set: the set to record dist n-gram
dist_cnt: the count of dist n-gram evaluation
gen_metrics: the metrics to evaluate conversational model, including bleu, dist, embedding metrics, f1
optim_metrics: the metrics to optimize in training
"""
def __init__(self, language, file_path=None):
super(StandardEvaluator, self).__init__()
self.file_path = file_path
self.result_data = []
# rec
self.rec_metrics = Metrics()
# gen
self.dist_set = defaultdict(set)
self.dist_cnt = 0
self.gen_metrics = Metrics()
# optim
self.optim_metrics = Metrics()
def rec_evaluate(self, ranks, label):
for k in [1, 10, 50]:
if len(ranks) >= k:
self.rec_metrics.add(f"recall@{k}", RECMetric.compute(ranks, label, k))
self.rec_metrics.add(f"ndcg@{k}", NDCGMetric.compute(ranks, label, k))
self.rec_metrics.add(f"mrr@{k}", MRRMetric.compute(ranks, label, k))
def gen_evaluate(self, hyp, refs, seq=None):
if hyp:
self.gen_metrics.add("f1", F1Metric.compute(hyp, refs))
for k in range(1, 5):
self.gen_metrics.add(f"bleu@{k}", BleuMetric.compute(hyp, refs, k))
for token in ngrams(seq, k):
self.dist_set[f"dist@{k}"].add(token)
self.dist_cnt += 1
def report(self, epoch=-1, mode='test'):
for k, v in self.dist_set.items():
self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt))
reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()]
all_reports = aggregate_unnamed_reports(reports)
self.result_data.append({
'epoch': epoch,
'mode': mode,
'report': {k:all_reports[k].value() for k in all_reports}
})
if self.file_path:
json.dump(self.result_data, open(self.file_path, "w", encoding="utf-8"), indent=4, ensure_ascii=False)
logger.info('\n' + nice_report(all_reports))
def reset_metrics(self):
# rec
self.rec_metrics.clear()
# conv
self.gen_metrics.clear()
self.dist_cnt = 0
self.dist_set.clear()
# optim
self.optim_metrics.clear()
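# Minimal usage sketch (argument values are illustrative; the file path is hypothetical):
#
#     evaluator = StandardEvaluator(language='en', file_path='res/report.json')
#     evaluator.rec_evaluate(ranks=list(range(100)), label=3)   # recall/ndcg/mrr @1/10/50
#     evaluator.gen_evaluate(hyp='a great movie', refs=['a great movie'],
#                            seq=['a', 'great', 'movie'])       # f1, bleu@k, dist@k
#     evaluator.report(epoch=0, mode='valid')                   # logs and optionally dumps JSON
#     evaluator.reset_metrics()
#
# recall/ndcg/mrr@k are only added when at least k candidates have been ranked.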

View File

@@ -0,0 +1,160 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/17
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/12/17
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
import json
import re
import shutil
from collections import OrderedDict
import math
import torch
from typing import Union, Tuple
from .metrics import Metric
def _line_width():
try:
# if we're in an interactive ipython notebook, hardcode a longer width
__IPYTHON__ # type: ignore
return 128
except NameError:
return shutil.get_terminal_size((88, 24)).columns
def float_formatter(f: Union[float, int]) -> str:
"""
Format a float as a pretty string.
"""
if f != f:
# instead of returning nan, return "" so it shows blank in table
return ""
if isinstance(f, int):
# don't do any rounding of integers, leave them alone
return str(f)
if f >= 1000:
# numbers > 1000 just round to the nearest integer
s = f'{f:.0f}'
else:
# otherwise show 4 significant figures, regardless of decimal spot
s = f'{f:.4g}'
# replace leading 0's with blanks for easier reading
# example: -0.32 to -.32
s = s.replace('-0.', '-.')
if s.startswith('0.'):
s = s[1:]
# Add the trailing 0's to always show 4 digits
# example: .32 to .3200
if s[0] == '.' and len(s) < 5:
s += '0' * (5 - len(s))
return s
def round_sigfigs(x: Union[float, 'torch.Tensor'], sigfigs=4) -> float:
"""
Round value to specified significant figures.
:param x: input number
:param sigfigs: number of significant figures to return
:returns: float number rounded to specified sigfigs
"""
x_: float
if isinstance(x, torch.Tensor):
x_ = x.item()
else:
x_ = x # type: ignore
try:
if x_ == 0:
return 0
return round(x_, -math.floor(math.log10(abs(x_)) - sigfigs + 1))
except (ValueError, OverflowError) as ex:
if x_ in [float('inf'), float('-inf')] or x_ != x_: # inf or nan
return x_
else:
raise ex
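# A few concrete values (illustrative) showing the intended formatting behaviour:
#
#     float_formatter(3.14159)   # '3.142'  -> 4 significant figures
#     float_formatter(0.5)       # '.5000'  -> leading zero stripped, padded to 4 digits
#     float_formatter(1234.7)    # '1235'   -> values >= 1000 round to an integer
#     round_sigfigs(3.14159, 3)  # 3.14
#     round_sigfigs(torch.tensor(0.061234), 2)  # 0.061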
def _report_sort_key(report_key: str) -> Tuple[str, str]:
"""
Sorting name for reports.
Sorts by main metric alphabetically, then by task.
"""
# if metric is on its own, like "f1", we will return ('', 'f1')
# if metric is from multitask, we denote it.
# e.g. "convai2/f1" -> ('convai2', 'f1')
# we handle multiple cases of / because sometimes teacher IDs have
# filenames.
fields = report_key.split("/")
main_key = fields.pop(-1)
sub_key = '/'.join(fields)
return (sub_key or 'all', main_key)
def nice_report(report) -> str:
"""
Render an agent Report as a beautiful string.
If pandas is installed, we will use it to render as a table. Multitask
metrics will be shown per row, e.g.
.. code-block:
f1 ppl
all .410 27.0
task1 .400 32.0
task2 .420 22.0
If pandas is not available, we will use a dict with like-metrics placed
next to each other.
"""
if not report:
return ""
try:
import pandas as pd
use_pandas = True
except ImportError:
use_pandas = False
sorted_keys = sorted(report.keys(), key=_report_sort_key)
output: OrderedDict[Union[str, Tuple[str, str]], float] = OrderedDict()
for k in sorted_keys:
v = report[k]
if isinstance(v, Metric):
v = v.value()
if use_pandas:
output[_report_sort_key(k)] = v
else:
output[k] = v
if use_pandas:
line_width = _line_width()
df = pd.DataFrame([output])
df.columns = pd.MultiIndex.from_tuples(df.columns)
df = df.stack().transpose().droplevel(0, axis=1)
result = " " + df.to_string(
na_rep="",
line_width=line_width - 3, # -3 for the extra spaces we add
float_format=float_formatter,
index=df.shape[0] > 1,
).replace("\n\n", "\n").replace("\n", "\n ")
result = re.sub(r"\s+$", "", result)
return result
else:
return json.dumps(
{
k: round_sigfigs(v, 4) if isinstance(v, float) else v
for k, v in output.items()
}
)

View File

@@ -0,0 +1,34 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24, 2020/12/24
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
# @Time : 2021/10/06
# @Author : Zhipeng Zhao
# @Email : oran_official@outlook.com
import torch
from loguru import logger
from .crs import *
Model_register_table = {
'HyCoRec': HyCoRecModel,
}
def get_model(config, model_name, device, vocab, side_data=None):
if model_name in Model_register_table:
model = Model_register_table[model_name](config, device, vocab, side_data)
logger.info(f'[Build model {model_name}]')
if config.opt["gpu"] == [-1]:
return model
else:
return torch.nn.DataParallel(model, device_ids=config["gpu"])
else:
raise NotImplementedError('Model [{}] has not been implemented'.format(model_name))

View File

@@ -0,0 +1,62 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24, 2020/12/29
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
from abc import ABC, abstractmethod
from torch import nn
from crslab.download import build
class BaseModel(ABC, nn.Module):
"""Base class for all models"""
def __init__(self, opt, device, dpath=None, resource=None):
super(BaseModel, self).__init__()
self.opt = opt
self.device = device
if resource is not None:
self.dpath = dpath
dfile = resource['file']
build(dpath, dfile, version=resource['version'])
self.build_model()
@abstractmethod
def build_model(self, *args, **kwargs):
"""build model"""
pass
def recommend(self, batch, mode):
"""Calculate the loss and predictions of the recommendation task for a batch under the given mode.
Args:
batch (dict or tuple): batch data
mode (str): one of 'train', 'valid' or 'test'.
"""
pass
def converse(self, batch, mode):
"""Calculate the loss and predictions of the conversation task for a batch under the given mode.
Args:
batch (dict or tuple): batch data
mode (str): one of 'train', 'valid' or 'test'.
"""
pass
def guide(self, batch, mode):
"""Calculate the loss and predictions of the guidance task for a batch under the given mode.
Args:
batch (dict or tuple): batch data
mode (str): one of 'train', 'valid' or 'test'.
"""
pass

View File

@@ -0,0 +1 @@
from .hycorec import *

View File

@@ -0,0 +1 @@
from .hycorec import HyCoRecModel

View File

@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# @Time : 2021/4/1
# @Author : Chenzhan Shang
# @Email : czshang@outlook.com
import torch
import torch.nn as nn
import torch.nn.functional as F
# Multi-head
class MHItemAttention(nn.Module):
def __init__(self, dim, head_num):
super(MHItemAttention, self).__init__()
self.MHA = torch.nn.MultiheadAttention(dim, head_num, batch_first=True)
def forward(self, related_entity, context_entity):
"""
input:
related_entity: (n_r, dim)
context_entity: (n_c, dim)
output:
related_context_entity: (n_c, dim)
"""
context_entity = torch.unsqueeze(context_entity, 0)
related_entity = torch.unsqueeze(related_entity, 0)
output, _ = self.MHA(context_entity, related_entity, related_entity)
return torch.squeeze(output, 0)
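# Shape sketch (illustrative; dim must be divisible by head_num):
#
#     attn = MHItemAttention(dim=128, head_num=4)
#     related = torch.randn(10, 128)    # n_r = 10 related entities
#     context = torch.randn(3, 128)     # n_c = 3 context entities
#     fused = attn(related, context)    # (3, 128): each context entity attends over the related ones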

View File

@@ -0,0 +1,200 @@
import torch
import numpy as np
from torch import nn
from crslab.model.utils.modules.transformer import MultiHeadAttention, TransformerFFN, _normalize, create_position_codes
class TransformerDecoderLayerKG(nn.Module):
def __init__(
self,
n_heads,
embedding_size,
ffn_size,
attention_dropout=0.0,
relu_dropout=0.0,
dropout=0.0,
):
super().__init__()
self.dim = embedding_size
self.ffn_dim = ffn_size
self.dropout = nn.Dropout(p=dropout)
self.self_attention = MultiHeadAttention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm_self_attention = nn.LayerNorm(embedding_size)
self.session_item_attention = MultiHeadAttention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm_session_item_attention = nn.LayerNorm(embedding_size)
self.related_encoder_attention = MultiHeadAttention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm_related_encoder_attention = nn.LayerNorm(embedding_size)
self.context_encoder_attention = MultiHeadAttention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm_context_encoder_attention = nn.LayerNorm(embedding_size)
self.norm_merge = nn.LayerNorm(embedding_size)
self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)
self.norm3 = nn.LayerNorm(embedding_size)
def forward(self, x, related_encoder_output, related_encoder_mask, context_encoder_output, context_encoder_mask,
session_embedding, session_mask):
decoder_mask = self._create_selfattn_mask(x)
# first self attn
residual = x
# don't peak into the future!
x = self.self_attention(query=x, mask=decoder_mask)
x = self.dropout(x) # --dropout
x = x + residual
x = _normalize(x, self.norm_self_attention)
residual = x
x = self.session_item_attention(
query=x,
key=session_embedding,
value=session_embedding,
mask=session_mask
)
x = self.dropout(x)
x = residual + x
x = _normalize(x, self.norm_session_item_attention)
residual = x
related_x = self.related_encoder_attention(
query=x,
key=related_encoder_output,
value=related_encoder_output,
mask=related_encoder_mask
)
related_x = self.dropout(related_x) # --dropout
context_x = self.context_encoder_attention(
query=x,
key=context_encoder_output,
value=context_encoder_output,
mask=context_encoder_mask
)
context_x = self.dropout(context_x) # --dropout
x = related_x * 0.1 + context_x * 0.9 + residual
x = _normalize(x, self.norm_merge)
# finally the ffn
residual = x
x = self.ffn(x)
x = self.dropout(x) # --dropout
x = residual + x
x = _normalize(x, self.norm3)
return x
def _create_selfattn_mask(self, x):
# figure out how many timestamps we need
bsz = x.size(0)
time = x.size(1)
# make sure that we don't look into the future
mask = torch.tril(x.new(time, time).fill_(1))
# broadcast across batch
mask = mask.unsqueeze(0).expand(bsz, -1, -1)
return mask
class TransformerDecoderKG(nn.Module):
"""
Transformer Decoder layer.
:param int n_heads: the number of multihead attention heads.
:param int n_layers: number of transformer layers.
:param int embedding_size: the embedding sizes. Must be a multiple of n_heads.
:param int ffn_size: the size of the hidden layer in the FFN
:param embedding: an embedding matrix for the bottom layer of the transformer.
If none, one is created for this encoder.
:param float dropout: Dropout used around embeddings and before layer
normalizations. This is used in Vaswani 2017 and works well on
large datasets.
:param float attention_dropout: Dropout performed after the multihead attention
softmax. This is not used in Vaswani 2017.
:param int padding_idx: Reserved padding index in the embeddings matrix.
:param bool learn_positional_embeddings: If off, sinusoidal embeddings are
used. If on, position embeddings are learned from scratch.
:param bool embeddings_scale: Scale embeddings relative to their dimensionality.
Found useful in fairseq.
:param int n_positions: Size of the position embeddings matrix.
"""
def __init__(
self,
n_heads,
n_layers,
embedding_size,
ffn_size,
vocabulary_size,
embedding=None,
dropout=0.0,
attention_dropout=0.0,
relu_dropout=0.0,
embeddings_scale=True,
learn_positional_embeddings=False,
padding_idx=None,
n_positions=1024,
):
super().__init__()
self.embedding_size = embedding_size
self.ffn_size = ffn_size
self.n_layers = n_layers
self.n_heads = n_heads
self.dim = embedding_size
self.embeddings_scale = embeddings_scale
self.dropout = nn.Dropout(p=dropout) # --dropout
self.out_dim = embedding_size
assert embedding_size % n_heads == 0, \
'Transformer embedding size must be a multiple of n_heads'
self.embeddings = embedding
# create the positional embeddings
self.position_embeddings = nn.Embedding(n_positions, embedding_size)
if not learn_positional_embeddings:
create_position_codes(
n_positions, embedding_size, out=self.position_embeddings.weight
)
else:
nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)
# build the model
self.layers = nn.ModuleList()
for _ in range(self.n_layers):
self.layers.append(TransformerDecoderLayerKG(
n_heads, embedding_size, ffn_size,
attention_dropout=attention_dropout,
relu_dropout=relu_dropout,
dropout=dropout,
))
def forward(self, input, related_encoder_state, context_encoder_state, session_state, incr_state=None):
related_encoder_output, related_encoder_mask = related_encoder_state
context_encoder_output, context_encoder_mask = context_encoder_state
session_embedding, session_mask = session_state
seq_len = input.shape[1]
positions = input.new_empty(seq_len).long()
positions = torch.arange(seq_len, out=positions).unsqueeze(0) # (batch, seq_len)
tensor = self.embeddings(input)
if self.embeddings_scale:
tensor = tensor * np.sqrt(self.dim)
tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
tensor = self.dropout(tensor) # --dropout
for layer in self.layers:
tensor = layer(tensor, related_encoder_output, related_encoder_mask, context_encoder_output,
context_encoder_mask, session_embedding, session_mask)
return tensor, None
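# Shape sketch for one decoding call (illustrative tensors only):
#
#     related_state = (related_out, related_mask)    # (bsz, len_r, emb), (bsz, len_r)
#     context_state = (context_out, context_mask)    # (bsz, len_c, emb), (bsz, len_c)
#     session_state = (session_emb, session_mask)    # (bsz, len_s, emb), (bsz, len_s)
#     tokens = resp[:, :t]                           # (bsz, t) token ids decoded so far
#     latent, _ = decoder(tokens, related_state, context_state, session_state)
#     # latent: (bsz, t, emb); the second return value (incremental state) is unused here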

View File

@@ -0,0 +1,729 @@
# -*- encoding: utf-8 -*-
# @Time : 2021/5/26
# @Author : Chenzhan Shang
# @email : czshang@outlook.com
r"""
HyCoRec
=======
References:
Chen, Qibin, et al. `"Towards Knowledge-Based Recommender Dialog System."`_ in EMNLP 2019.
.. _`"Towards Knowledge-Based Recommender Dialog System."`:
https://www.aclweb.org/anthology/D19-1189/
"""
import json
import os.path
import random
import pickle
from typing import List
from time import perf_counter
import torch
import torch.nn.functional as F
from loguru import logger
from torch import nn
from tqdm import tqdm
from torch_geometric.nn import RGCNConv, HypergraphConv
from crslab.config import DATA_PATH, DATASET_PATH
from crslab.model.base import BaseModel
from crslab.model.crs.hycorec.attention import MHItemAttention
from crslab.model.utils.functions import edge_to_pyg_format
from crslab.model.utils.modules.attention import SelfAttentionBatch, SelfAttentionSeq
from crslab.model.utils.modules.transformer import TransformerEncoder
from crslab.model.crs.hycorec.decoder import TransformerDecoderKG
class HyCoRecModel(BaseModel):
"""
Attributes:
vocab_size: An integer indicating the vocabulary size.
pad_token_idx: An integer indicating the id of the padding token.
start_token_idx: An integer indicating the id of the start token.
end_token_idx: An integer indicating the id of the end token.
token_emb_dim: An integer indicating the dimension of the token embedding layer.
pretrain_embedding: A string indicating the path of the pretrained embedding.
n_entity: An integer indicating the number of entities.
n_relation: An integer indicating the number of relations in the KG.
num_bases: An integer indicating the number of bases.
kg_emb_dim: An integer indicating the dimension of the KG embedding.
user_emb_dim: An integer indicating the dimension of the user embedding.
n_heads: An integer indicating the number of attention heads.
n_layers: An integer indicating the number of layers.
ffn_size: An integer indicating the hidden size of the FFN.
dropout: A float indicating the dropout rate.
attention_dropout: A float indicating the dropout rate of the attention layer.
relu_dropout: A float indicating the dropout rate of the ReLU layer.
learn_positional_embeddings: A boolean indicating whether to learn the positional embeddings.
embeddings_scale: A boolean indicating whether to scale the embeddings.
reduction: A boolean indicating whether to use reduction.
n_positions: An integer indicating the number of positions.
longest_label: An integer indicating the longest length for response generation.
user_proj_dim: An integer indicating the dimension to project the user embedding to.
"""
def __init__(self, opt, device, vocab, side_data):
"""
Args:
opt (dict): A dictionary record the hyper parameters.
device (torch.device): A variable indicating which device to place the data and model.
vocab (dict): A dictionary record the vocabulary information.
side_data (dict): A dictionary record the side data.
"""
self.device = device
self.gpu = opt.get("gpu", -1)
self.dataset = opt.get("dataset", None)
self.llm = opt.get("llm", "chatgpt-4o")
assert self.dataset in ['HReDial', 'HTGReDial', 'DuRecDial', 'OpenDialKG', 'ReDial', 'TGReDial']
# vocab
self.pad_token_idx = vocab['tok2ind']['__pad__']
self.start_token_idx = vocab['tok2ind']['__start__']
self.end_token_idx = vocab['tok2ind']['__end__']
self.vocab_size = vocab['vocab_size']
self.token_emb_dim = opt.get('token_emb_dim', 300)
self.pretrain_embedding = side_data.get('embedding', None)
self.token2id = json.load(open(os.path.join(DATASET_PATH, self.dataset.lower(), opt["tokenize"], "token2id.json"), "r", encoding="utf-8"))
self.entity2id = json.load(open(os.path.join(DATASET_PATH, self.dataset.lower(), opt["tokenize"], "entity2id.json"), "r", encoding="utf-8"))
# kg
self.n_entity = vocab['n_entity']
self.entity_kg = side_data['entity_kg']
self.n_relation = self.entity_kg['n_relation']
self.edge_idx, self.edge_type = edge_to_pyg_format(self.entity_kg['edge'], 'RGCN')
self.edge_idx = self.edge_idx.to(device)
self.edge_type = self.edge_type.to(device)
self.num_bases = opt.get('num_bases', 8)
self.kg_emb_dim = opt.get('kg_emb_dim', 300)
self.user_emb_dim = self.kg_emb_dim
# transformer
self.n_heads = opt.get('n_heads', 2)
self.n_layers = opt.get('n_layers', 2)
self.ffn_size = opt.get('ffn_size', 300)
self.dropout = opt.get('dropout', 0.1)
self.attention_dropout = opt.get('attention_dropout', 0.0)
self.relu_dropout = opt.get('relu_dropout', 0.1)
self.embeddings_scale = opt.get('embeddings_scale', True)
self.learn_positional_embeddings = opt.get('learn_positional_embeddings', False)
self.reduction = opt.get('reduction', False)
self.n_positions = opt.get('n_positions', 1024)
self.longest_label = opt.get('longest_label', 30)
self.user_proj_dim = opt.get('user_proj_dim', 512)
# pooling
self.pooling = opt.get('pooling', None)
assert self.pooling == 'Attn' or self.pooling == 'Mean'
# MHA
self.mha_n_heads = opt.get('mha_n_heads', 4)
self.extension_strategy = opt.get('extension_strategy', None)
self.pretrain = opt.get('pretrain', False)
self.pretrain_data = None
self.pretrain_epoch = opt.get('pretrain_epoch', 9999)
super(HyCoRecModel, self).__init__(opt, device)
return
# Build the model
def build_model(self, *args, **kwargs):
if self.pretrain:
pretrain_file = os.path.join('pretrain', self.dataset, str(self.pretrain_epoch) + '-epoch.pth')
self.pretrain_data = torch.load(pretrain_file, map_location=torch.device('cuda:' + str(self.gpu[0])))
logger.info(f"[Load Pretrain Weights from {pretrain_file}]")
# self._build_hredial_copy_mask()
self._build_adjacent_matrix()
# self._build_hllm_data()
self._build_embedding()
self._build_kg_layer()
self._build_recommendation_layer()
self._build_conversation_layer()
# Build the copy mask
def _build_hredial_copy_mask(self):
token_filename = os.path.join(DATASET_PATH, "hredial", "nltk", "token2id.json")
token_file = open(token_filename, 'r', encoding="utf-8")
token2id = json.load(token_file)
id2token = {token2id[token]: token for token in token2id}
self.hredial_copy_mask = list()
for i in range(len(id2token)):
token = id2token[i]
if token[0] == '@':
self.hredial_copy_mask.append(True)
else:
self.hredial_copy_mask.append(False)
self.hredial_copy_mask = torch.as_tensor(self.hredial_copy_mask).to(self.device)
return
def _build_hllm_data(self):
self.hllm_data_table = {
"train": pickle.load(open(os.path.join(DATA_PATH, "hllm", self.dataset.lower(), self.llm, "hllm_train_data.pkl"), "rb")),
"valid": pickle.load(open(os.path.join(DATA_PATH, "hllm", self.dataset.lower(), self.llm, "hllm_valid_data.pkl"), "rb")),
"test": pickle.load(open(os.path.join(DATA_PATH, "hllm", self.dataset.lower(), self.llm, "hllm_test_data.pkl"), "rb")),
}
return
# Build the adjacency tables for items, entities and words
def _build_adjacent_matrix(self):
entity2id = self.entity2id
token2id = self.token2id
item_edger = pickle.load(open(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "item_edger.pkl"), "rb"))
entity_edger = pickle.load(open(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "entity_edger.pkl"), "rb"))
word_edger = pickle.load(open(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "word_edger.pkl"), "rb"))
# map raw ids to graph ids, keeping only neighbours that appear in the vocabulary
item_adj = {}
for item_a in item_edger:
raw_item_list = item_edger[item_a]
if item_a not in entity2id:
continue
item_a = entity2id[item_a]
item_list = []
for item in raw_item_list:
if item not in entity2id:
continue
item_list.append(entity2id[item])
item_adj[item_a] = item_list
self.item_adj = item_adj
entity_adj = {}
for entity_a in entity_edger:
raw_entity_list = entity_edger[entity_a]
if entity_a not in entity2id:
continue
entity_a = entity2id[entity_a]
entity_list = []
for entity in raw_entity_list:
if entity not in entity2id:
continue
entity_list.append(entity2id[entity])
entity_adj[entity_a] = entity_list
self.entity_adj = entity_adj
word_adj = {}
for word_a in word_edger:
raw_word_list = word_edger[word_a]
if word_a not in token2id:
continue
word_a = token2id[word_a]
word_list = []
for word in raw_word_list:
if word not in token2id:
continue
word_list.append(token2id[word])
word_adj[word_a] = word_list
self.word_adj = word_adj
logger.info(f"[Adjacent Matrix built.]")
return
# Build the embedding layers
def _build_embedding(self):
if self.pretrain_embedding is not None:
self.token_embedding = nn.Embedding.from_pretrained(
torch.as_tensor(self.pretrain_embedding, dtype=torch.float), freeze=False,
padding_idx=self.pad_token_idx)
else:
self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx)
nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)
nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0)
self.entity_embedding = nn.Embedding(self.n_entity, self.kg_emb_dim, 0)
nn.init.normal_(self.entity_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)
nn.init.constant_(self.entity_embedding.weight[0], 0)
self.word_embedding = nn.Embedding(self.n_entity, self.kg_emb_dim, 0)
nn.init.normal_(self.word_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)
nn.init.constant_(self.word_embedding.weight[0], 0)
logger.debug('[Build embedding]')
return
# Build the hypergraph encoding layers
def _build_kg_layer(self):
# graph encoder
self.item_encoder = RGCNConv(self.kg_emb_dim, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases)
self.entity_encoder = RGCNConv(self.kg_emb_dim, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases)
self.word_encoder = RGCNConv(self.kg_emb_dim, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases)
if self.pretrain:
self.item_encoder.load_state_dict(self.pretrain_data['encoder'])
# hypergraph convolution
self.hyper_conv_item = HypergraphConv(self.kg_emb_dim, self.kg_emb_dim)
self.hyper_conv_entity = HypergraphConv(self.kg_emb_dim, self.kg_emb_dim)
self.hyper_conv_word = HypergraphConv(self.kg_emb_dim, self.kg_emb_dim)
# attention type
self.item_attn = MHItemAttention(self.kg_emb_dim, self.mha_n_heads)
# pooling
if self.pooling == 'Attn':
self.kg_attn = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim)
self.kg_attn_his = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim)
logger.debug('[Build kg layer]')
return
# Build the recommendation layer
def _build_recommendation_layer(self):
self.rec_bias = nn.Linear(self.kg_emb_dim, self.n_entity)
self.rec_loss = nn.CrossEntropyLoss()
logger.debug('[Build recommendation layer]')
return
# Build the conversation layer
def _build_conversation_layer(self):
self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long))
self.entity_to_token = nn.Linear(self.kg_emb_dim, self.token_emb_dim, bias=True)
self.related_encoder = TransformerEncoder(
self.n_heads,
self.n_layers,
self.token_emb_dim,
self.ffn_size,
self.vocab_size,
self.token_embedding,
self.dropout,
self.attention_dropout,
self.relu_dropout,
self.pad_token_idx,
self.learn_positional_embeddings,
self.embeddings_scale,
self.reduction,
self.n_positions
)
self.context_encoder = TransformerEncoder(
self.n_heads,
self.n_layers,
self.token_emb_dim,
self.ffn_size,
self.vocab_size,
self.token_embedding,
self.dropout,
self.attention_dropout,
self.relu_dropout,
self.pad_token_idx,
self.learn_positional_embeddings,
self.embeddings_scale,
self.reduction,
self.n_positions
)
self.decoder = TransformerDecoderKG(
self.n_heads,
self.n_layers,
self.token_emb_dim,
self.ffn_size,
self.vocab_size,
self.token_embedding,
self.dropout,
self.attention_dropout,
self.relu_dropout,
self.embeddings_scale,
self.learn_positional_embeddings,
self.pad_token_idx,
self.n_positions
)
self.user_proj_1 = nn.Linear(self.user_emb_dim, self.user_proj_dim)
self.user_proj_2 = nn.Linear(self.user_proj_dim, self.vocab_size)
self.conv_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx)
self.copy_proj_1 = nn.Linear(2 * self.token_emb_dim, self.token_emb_dim)
self.copy_proj_2 = nn.Linear(self.token_emb_dim, self.vocab_size)
logger.debug('[Build conversation layer]')
return
# Construct the hypergraph from related items and their neighbours
def _get_hypergraph(self, related, adj):
session_related_items = list(set(related))
hypergraph_nodes, hypergraph_edges, hyper_edge_counter = list(), list(), 0
for item in session_related_items:
hypergraph_nodes.append(item)
hypergraph_edges.append(hyper_edge_counter)
neighbors = list(adj.get(item, []))
hypergraph_nodes += neighbors
hypergraph_edges += [hyper_edge_counter] * len(neighbors)
hyper_edge_counter += 1
hyper_edge_index = torch.tensor([hypergraph_nodes, hypergraph_edges], device=self.device)
return list(set(hypergraph_nodes)), hyper_edge_index
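# Worked example (illustrative ids): with related = [3, 7] and adj = {3: [5, 9], 7: [5]},
# each item spawns one hyperedge covering the item and its neighbours, so (up to set ordering)
#     hypergraph_nodes = [3, 5, 9, 7, 5]   ->  unique nodes [3, 5, 7, 9]
#     hypergraph_edges = [0, 0, 0, 1, 1]
# and hyper_edge_index stacks the two lists into a (2, nnz) tensor for HypergraphConv.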
# Aggregate embeddings over each item's sub-graph
def _get_embedding(self, hypergraph_items, embedding, tot_sub, adj):
knowledge_embedding_list = []
for item in hypergraph_items:
sub_graph = [item] + list(adj.get(item, []))
sub_graph = [tot_sub[item] for item in sub_graph]
sub_graph_embedding = embedding[sub_graph]
sub_graph_embedding = torch.mean(sub_graph_embedding, dim=0)
knowledge_embedding_list.append(sub_graph_embedding)
res_embedding = torch.zeros(1, self.kg_emb_dim).to(self.device)
if len(knowledge_embedding_list) > 0:
res_embedding = torch.stack(knowledge_embedding_list, dim=0)
return res_embedding
@staticmethod
def flatten(inputs):
outputs = set()
for li in inputs:
for i in li:
outputs.add(i)
return list(outputs)
# Fuse feature vectors with attention
def _attention_and_gating(self, session_embedding, knowledge_embedding, conceptnet_embedding, context_embedding):
related_embedding = torch.cat((session_embedding, knowledge_embedding, conceptnet_embedding), dim=0)
if context_embedding is None:
if self.pooling == 'Attn':
user_repr = self.kg_attn_his(related_embedding)
else:
assert self.pooling == 'Mean'
user_repr = torch.mean(related_embedding, dim=0)
elif self.pooling == 'Attn':
attentive_related_embedding = self.item_attn(related_embedding, context_embedding)
user_repr = self.kg_attn_his(attentive_related_embedding)
user_repr = torch.unsqueeze(user_repr, dim=0)
user_repr = torch.cat((context_embedding, user_repr), dim=0)
user_repr = self.kg_attn(user_repr)
else:
assert self.pooling == 'Mean'
attentive_related_embedding = self.item_attn(related_embedding, context_embedding)
user_repr = torch.mean(attentive_related_embedding, dim=0)
user_repr = torch.unsqueeze(user_repr, dim=0)
user_repr = torch.cat((context_embedding, user_repr), dim=0)
user_repr = torch.mean(user_repr, dim=0)
return user_repr
def _get_hllm_embedding(self, tot_embedding, hllm_hyper_graph, adj, conv):
hllm_hyper_edge_A = []
hllm_hyper_edge_B = []
for idx, hyper_edge in enumerate(hllm_hyper_graph):
hllm_hyper_edge_A += [item for item in hyper_edge]
hllm_hyper_edge_B += [idx] * len(hyper_edge)
hllm_items = list(set(hllm_hyper_edge_A))
sub_item2id = {item:idx for idx, item in enumerate(hllm_items)}
sub_embedding = tot_embedding[hllm_items]
hllm_hyper_edge = [[sub_item2id[item] for item in hllm_hyper_edge_A], hllm_hyper_edge_B]
hllm_hyper_edge = torch.LongTensor(hllm_hyper_edge).to(self.device)
embedding = conv(sub_embedding, hllm_hyper_edge)
return embedding
def encode_user_repr(self, related_items, related_entities, related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding):
# COLD START
# if len(related_items) == 0 or len(related_words) == 0:
# if len(related_entities) == 0:
# user_repr = torch.zeros(self.user_emb_dim, device=self.device)
# elif self.pooling == 'Attn':
# user_repr = tot_entity_embedding[related_entities]
# user_repr = self.kg_attn(user_repr)
# else:
# assert self.pooling == 'Mean'
# user_repr = tot_entity_embedding[related_entities]
# user_repr = torch.mean(user_repr, dim=0)
# return user_repr
# Hypergraph-convolved embeddings
item_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
if len(related_items) > 0:
items, item_hyper_edge_index = self._get_hypergraph(related_items, self.item_adj)
sub_item_embedding, sub_item_edge_index, item_tot2sub = self._before_hyperconv(tot_item_embedding, items, item_hyper_edge_index, self.item_adj)
raw_item_embedding = self.hyper_conv_item(sub_item_embedding, sub_item_edge_index)
item_embedding = raw_item_embedding
# item_embedding = self._get_embedding(items, raw_item_embedding, item_tot2sub, self.item_adj)
entity_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
if len(related_entities) > 0:
entities, entity_hyper_edge_index = self._get_hypergraph(related_entities, self.entity_adj)
sub_entity_embedding, sub_entity_edge_index, entity_tot2sub = self._before_hyperconv(tot_entity_embedding, entities, entity_hyper_edge_index, self.entity_adj)
raw_entity_embedding = self.hyper_conv_entity(sub_entity_embedding, sub_entity_edge_index)
entity_embedding = raw_entity_embedding
# entity_embedding = self._get_embedding(entities, raw_entity_embedding, entity_tot2sub, self.entity_adj)
word_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
if len(related_words) > 0:
words, word_hyper_edge_index = self._get_hypergraph(related_words, self.word_adj)
sub_word_embedding, sub_word_edge_index, word_tot2sub = self._before_hyperconv(tot_word_embedding, words, word_hyper_edge_index, self.word_adj)
raw_word_embedding = self.hyper_conv_word(sub_word_embedding, sub_word_edge_index)
word_embedding = raw_word_embedding
# word_embedding = self._get_embedding(words, raw_word_embedding, word_tot2sub, self.word_adj)
# Attention-based fusion
if len(related_entities) == 0:
user_repr = self._attention_and_gating(item_embedding, entity_embedding, word_embedding, None)
else:
context_embedding = tot_entity_embedding[related_entities]
user_repr = self._attention_and_gating(item_embedding, entity_embedding, word_embedding, context_embedding)
return user_repr
def process_hllm(self, hllm_data, id_dict):
res_data = []
for raw_hyper_graph in hllm_data:
if not isinstance(raw_hyper_graph, list):
continue
temp_hyper_graph = []
for meta_data in raw_hyper_graph:
if not isinstance(meta_data, int):
continue
if meta_data not in id_dict:
continue
temp_hyper_graph.append(id_dict[meta_data])
res_data.append(temp_hyper_graph)
return res_data
# Encode user representations
def encode_user(self, batch_related_items, batch_related_entities, batch_related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding):
user_repr_list = []
for related_items, related_entities, related_words in zip(batch_related_items, batch_related_entities, batch_related_words):
user_repr = self.encode_user_repr(related_items, related_entities, related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding)
user_repr_list.append(user_repr)
user_embedding = torch.stack(user_repr_list, dim=0)
# print("user_embedding.shape", user_embedding.shape) # [6, 128]
return user_embedding
# Recommendation module
def recommend(self, batch, mode):
# Unpack batch data
conv_id = batch['conv_id']
related_item = batch['related_item']
related_entity = batch['related_entity']
related_word = batch['related_word']
item = batch['item']
item_embedding = self.item_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type)
entity_embedding = self.entity_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type)
token_embedding = self.word_encoder(self.word_embedding.weight, self.edge_idx, self.edge_type)
# Encode the user
# start = perf_counter()
user_embedding = self.encode_user(
related_item,
related_entity,
related_word,
item_embedding,
entity_embedding,
token_embedding,
) # (batch_size, emb_dim)
# print(f"{perf_counter() - start:.2f}")
# Compute a score for every entity
scores = F.linear(user_embedding, entity_embedding, self.rec_bias.bias) # (batch_size, n_entity)
loss = self.rec_loss(scores, item)
return loss, scores
def _starts(self, batch_size):
"""Return bsz start tokens."""
return self.START.detach().expand(batch_size, 1)
def freeze_parameters(self):
freeze_models = [
self.entity_embedding,
self.token_embedding,
self.item_encoder,
self.entity_encoder,
self.word_encoder,
self.hyper_conv_item,
self.hyper_conv_entity,
self.hyper_conv_word,
self.item_attn,
self.rec_bias
]
if self.pooling == "Attn":
freeze_models.append(self.kg_attn)
freeze_models.append(self.kg_attn_his)
for model in freeze_models:
for p in model.parameters():
p.requires_grad = False
def _before_hyperconv(self, embeddings:torch.FloatTensor, hypergraph_items:List[int], edge_index:torch.LongTensor, adj):
sub_items = []
edge_index = edge_index.cpu().numpy()
for item in hypergraph_items:
sub_items += [item] + list(adj.get(item, []))
sub_items = list(set(sub_items))
tot2sub = {tot:sub for sub, tot in enumerate(sub_items)}
sub_embeddings = embeddings[sub_items]
edge_index = [[tot2sub[v] for v in edge_index[0]], list(edge_index[1])]
sub_edge_index = torch.tensor(edge_index).long()
sub_edge_index = sub_edge_index.to(self.device)
return sub_embeddings, sub_edge_index, tot2sub
# Encode the session from hypergraph-convolved data
def encode_session(self, batch_related_items, batch_related_entities, batch_related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding):
"""
Return: session_repr (batch_size, batch_seq_len, token_emb_dim), mask (batch_size, batch_seq_len)
"""
session_repr_list = []
for session_related_items, session_related_entities, session_related_words in zip(batch_related_items, batch_related_entities, batch_related_words):
# COLD START
# if len(session_related_items) == 0 or len(session_related_words) == 0:
# if len(session_related_entities) == 0:
# session_repr_list.append(None)
# else:
# session_repr = tot_entity_embedding[session_related_entities]
# session_repr_list.append(session_repr)
# continue
# Hypergraph-convolved embeddings
item_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
if len(session_related_items) > 0:
items, item_hyper_edge_index = self._get_hypergraph(session_related_items, self.item_adj)
sub_item_embedding, sub_item_edge_index, item_tot2sub = self._before_hyperconv(tot_item_embedding, items, item_hyper_edge_index, self.item_adj)
raw_item_embedding = self.hyper_conv_item(sub_item_embedding, sub_item_edge_index)
item_embedding = raw_item_embedding
# item_embedding = self._get_embedding(items, raw_item_embedding, item_tot2sub, self.item_adj)
entity_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
if len(session_related_entities) > 0:
entities, entity_hyper_edge_index = self._get_hypergraph(session_related_entities, self.entity_adj)
sub_entity_embedding, sub_entity_edge_index, entity_tot2sub = self._before_hyperconv(tot_entity_embedding, entities, entity_hyper_edge_index, self.entity_adj)
raw_entity_embedding = self.hyper_conv_entity(sub_entity_embedding, sub_entity_edge_index)
entity_embedding = raw_entity_embedding
# entity_embedding = self._get_embedding(entities, raw_entity_embedding, entity_tot2sub, self.entity_adj)
word_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
if len(session_related_words) > 0:
words, word_hyper_edge_index = self._get_hypergraph(session_related_words, self.word_adj)
sub_word_embedding, sub_word_edge_index, word_tot2sub = self._before_hyperconv(tot_word_embedding, words, word_hyper_edge_index, self.word_adj)
raw_word_embedding = self.hyper_conv_word(sub_word_embedding, sub_word_edge_index)
word_embedding = raw_word_embedding
# word_embedding = self._get_embedding(words, raw_word_embedding, word_tot2sub, self.word_adj)
# Concatenate the representations
if len(session_related_entities) == 0:
session_repr = torch.cat((item_embedding, entity_embedding, word_embedding), dim=0)
session_repr_list.append(session_repr)
else:
context_embedding = tot_entity_embedding[session_related_entities]
session_repr = torch.cat((item_embedding, entity_embedding, context_embedding, word_embedding), dim=0)
session_repr_list.append(session_repr)
batch_seq_len = max([session_repr.size(0) for session_repr in session_repr_list if session_repr is not None])
mask_list = []
for i in range(len(session_repr_list)):
if session_repr_list[i] is None:
mask_list.append([False] * batch_seq_len)
zero_repr = torch.zeros((batch_seq_len, self.kg_emb_dim), device=self.device, dtype=torch.float)
session_repr_list[i] = zero_repr
else:
mask_list.append([False] * (batch_seq_len - session_repr_list[i].size(0)) + [True] * session_repr_list[i].size(0))
zero_repr = torch.zeros((batch_seq_len - session_repr_list[i].size(0), self.kg_emb_dim),
device=self.device, dtype=torch.float)
session_repr_list[i] = torch.cat((zero_repr, session_repr_list[i]), dim=0)
session_repr_embedding = torch.stack(session_repr_list, dim=0)
session_repr_embedding = self.entity_to_token(session_repr_embedding)
# print("session_repr_embedding.shape", session_repr_embedding.shape) # [6, 7, 300]
return session_repr_embedding, torch.tensor(mask_list, device=self.device, dtype=torch.bool)
# Response generation with teacher forcing
def decode_forced(self, related_encoder_state, context_encoder_state, session_state, user_embedding, resp):
bsz = resp.size(0)
seqlen = resp.size(1)
inputs = resp.narrow(1, 0, seqlen - 1)
inputs = torch.cat([self._starts(bsz), inputs], 1)
latent, _ = self.decoder(inputs, related_encoder_state, context_encoder_state, session_state)
token_logits = F.linear(latent, self.token_embedding.weight)
user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)
user_latent = self.entity_to_token(user_embedding)
user_latent = user_latent.unsqueeze(1).expand(-1, seqlen, -1)
copy_latent = torch.cat((user_latent, latent), dim=-1)
copy_logits = self.copy_proj_2(torch.relu(self.copy_proj_1(copy_latent)))
if self.dataset == 'HReDial':
copy_logits = copy_logits * self.hredial_copy_mask.unsqueeze(0).unsqueeze(0) # not for tg-redial
sum_logits = token_logits + user_logits + copy_logits
_, preds = sum_logits.max(dim=-1)
return sum_logits, preds
# Response generation - greedy decoding at test time
def decode_greedy(self, related_encoder_state, context_encoder_state, session_state, user_embedding):
bsz = context_encoder_state[0].shape[0]
xs = self._starts(bsz)
incr_state = None
logits = []
for i in range(self.longest_label):
scores, incr_state = self.decoder(xs, related_encoder_state, context_encoder_state, session_state, incr_state) # incr_state is always None
scores = scores[:, -1:, :]
token_logits = F.linear(scores, self.token_embedding.weight)
user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)
user_latent = self.entity_to_token(user_embedding)
user_latent = user_latent.unsqueeze(1).expand(-1, 1, -1)
copy_latent = torch.cat((user_latent, scores), dim=-1)
copy_logits = self.copy_proj_2(torch.relu(self.copy_proj_1(copy_latent)))
if self.dataset == 'HReDial':
copy_logits = copy_logits * self.hredial_copy_mask.unsqueeze(0).unsqueeze(0) # not for tg-redial
sum_logits = token_logits + user_logits + copy_logits
probs, preds = sum_logits.max(dim=-1)
logits.append(scores)
xs = torch.cat([xs, preds], dim=1)
# check if everyone has generated an end token
all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz
if all_finished:
break
logits = torch.cat(logits, 1)
return logits, xs
# Conversation module training
def converse(self, batch, mode):
# Unpack batch data
conv_id = batch['conv_id']
related_item = batch['related_item']
related_entity = batch['related_entity']
related_word = batch['related_word']
response = batch['response']
related_tokens = batch['related_tokens']
context_tokens = batch['context_tokens']
item_embedding = self.item_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type)
entity_embedding = self.entity_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type)
token_embedding = self.word_encoder(self.word_embedding.weight, self.edge_idx, self.edge_type)
# Encode the session
session_state = self.encode_session(
related_item,
related_entity,
related_word,
item_embedding,
entity_embedding,
token_embedding,
)
# Encode the user
# start = perf_counter()
user_embedding = self.encode_user(
related_item,
related_entity,
related_word,
item_embedding,
entity_embedding,
token_embedding,
) # (batch_size, emb_dim)
# Obtain X_c and X_h
related_encoder_state = self.related_encoder(related_tokens)
context_encoder_state = self.context_encoder(context_tokens)
# Generate the response
if mode != 'test':
self.longest_label = max(self.longest_label, response.shape[1])
logits, preds = self.decode_forced(related_encoder_state, context_encoder_state, session_state, user_embedding, response)
logits = logits.view(-1, logits.shape[-1])
labels = response.view(-1)
return self.conv_loss(logits, labels), preds
else:
_, preds = self.decode_greedy(related_encoder_state, context_encoder_state, session_state, user_embedding)
return preds
# The recommendation and conversation modules are trained separately
def forward(self, batch, mode, stage):
if len(self.gpu) >= 2:
self.edge_idx = self.edge_idx.cuda(torch.cuda.current_device())
self.edge_type = self.edge_type.cuda(torch.cuda.current_device())
if stage == "conv":
return self.converse(batch, mode)
if stage == "rec":
# start = perf_counter()
res = self.recommend(batch, mode)
# print(f"{perf_counter() - start:.2f}")
return res

View File

@@ -0,0 +1,64 @@
# -*- encoding: utf-8 -*-
# @Time : 2021/1/6
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2021/1/7
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from crslab.download import DownloadableFile
"""Download links of pretrain models.
Now we provide the following models:
- `BERT`_: zh, en
- `GPT2`_: zh, en
.. _BERT:
https://www.aclweb.org/anthology/N19-1423/
.. _GPT2:
https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
"""
resources = {
'bert': {
'zh': {
'version': '0.1',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1',
'bert_zh.zip',
'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3'
)
},
'en': {
'version': '0.1',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1',
'bert_en.zip',
'61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3'
)
},
},
'gpt2': {
'zh': {
'version': '0.1',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1',
'gpt2_zh.zip',
'5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43'
)
},
'en': {
'version': '0.1',
'file': DownloadableFile(
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1',
'gpt2_en.zip',
'518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e'
)
}
}
}

View File

View File

@@ -0,0 +1,37 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/11/26
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2020/11/16
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
import torch
def edge_to_pyg_format(edge, type='RGCN'):
if type == 'RGCN':
edge_sets = torch.as_tensor(edge, dtype=torch.long)
edge_idx = edge_sets[:, :2].t()
edge_type = edge_sets[:, 2]
return edge_idx, edge_type
elif type == 'GCN':
edge_set = [[co[0] for co in edge], [co[1] for co in edge]]
return torch.as_tensor(edge_set, dtype=torch.long)
else:
raise NotImplementedError('type {} has not been implemented'.format(type))
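# Worked example (illustrative triples): edge = [(0, 1, 2), (1, 2, 0)] gives, for 'RGCN',
#     edge_idx  = tensor([[0, 1],
#                         [1, 2]])    # shape (2, num_edges)
#     edge_type = tensor([2, 0])
# and, for 'GCN', a single (2, num_edges) index tensor without edge types.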
def sort_for_packed_sequence(lengths: torch.Tensor):
"""
:param lengths: 1D array of lengths
:return: sorted_lengths (lengths in descending order), sorted_idx (indices to sort), rev_idx (indices to retrieve original order)
"""
sorted_idx = torch.argsort(lengths, descending=True) # idx to sort by length
rev_idx = torch.argsort(sorted_idx) # idx to retrieve original order
sorted_lengths = lengths[sorted_idx]
return sorted_lengths, sorted_idx, rev_idx
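# Worked example: lengths = torch.tensor([3, 5, 2]) yields
#     sorted_lengths = tensor([5, 3, 2])
#     sorted_idx     = tensor([1, 0, 2])   # order that sorts the batch by length
#     rev_idx        = tensor([1, 0, 2])   # sorted_lengths[rev_idx] recovers the original lengths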

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttentionBatch(nn.Module):
def __init__(self, dim, da, alpha=0.2, dropout=0.5):
super(SelfAttentionBatch, self).__init__()
self.dim = dim
self.da = da
self.alpha = alpha
self.dropout = dropout
self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
nn.init.xavier_uniform_(self.a.data, gain=1.414)
nn.init.xavier_uniform_(self.b.data, gain=1.414)
def forward(self, h):
# h: (N, dim)
e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1)
attention = F.softmax(e, dim=0) # (N)
return torch.matmul(attention, h) # (dim)
class SelfAttentionSeq(nn.Module):
def __init__(self, dim, da, alpha=0.2, dropout=0.5):
super(SelfAttentionSeq, self).__init__()
self.dim = dim
self.da = da
self.alpha = alpha
self.dropout = dropout
self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
nn.init.xavier_uniform_(self.a.data, gain=1.414)
nn.init.xavier_uniform_(self.b.data, gain=1.414)
def forward(self, h, mask=None, return_logits=False):
"""
For padding tokens, the corresponding mask entries are True.
If a whole row is padding (mask == [1, 1, 1, ...]), masking is skipped for that row so the softmax stays well defined.
"""
# h: (batch, seq_len, dim), mask: (batch, seq_len)
e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b) # (batch, seq_len, 1)
if mask is not None:
full_mask = -1e30 * mask.float()
batch_mask = torch.sum((mask == False), -1).bool().float().unsqueeze(-1) # for all padding one, the mask=0
mask = full_mask * batch_mask
e += mask.unsqueeze(-1)
attention = F.softmax(e, dim=1) # (batch, seq_len, 1)
# (batch, dim)
if return_logits:
return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1), attention.squeeze(-1)
else:
return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1)
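# Shape sketch (illustrative): SelfAttentionBatch pools a set, SelfAttentionSeq pools a batch.
#
#     pool_set = SelfAttentionBatch(dim=128, da=128)
#     vec = pool_set(torch.randn(10, 128))              # (128,)
#
#     pool_seq = SelfAttentionSeq(dim=128, da=128)
#     h = torch.randn(4, 20, 128)                       # (batch, seq_len, dim)
#     pad_mask = torch.zeros(4, 20, dtype=torch.bool)   # True marks padding positions
#     out = pool_seq(h, mask=pad_mask)                  # (4, 128)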

View File

@@ -0,0 +1,471 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
"""Near infinity, useful as a large penalty for scoring when inf is bad."""
NEAR_INF = 1e20
NEAR_INF_FP16 = 65504
def neginf(dtype):
"""Returns a representable finite number near -inf for a dtype."""
if dtype is torch.float16:
return -NEAR_INF_FP16
else:
return -NEAR_INF
def _create_selfattn_mask(x):
# figure out how many timestamps we need
bsz = x.size(0)
time = x.size(1)
# make sure that we don't look into the future
mask = torch.tril(x.new(time, time).fill_(1))
# broadcast across batch
mask = mask.unsqueeze(0).expand(bsz, -1, -1)
return mask
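# For time = 3, the mask expands a lower-triangular matrix across the batch:
#     [[1, 0, 0],
#      [1, 1, 0],
#      [1, 1, 1]]
# so position t can only attend to positions <= t.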
def create_position_codes(n_pos, dim, out):
position_enc = np.array([
[pos / np.power(10000, 2 * j / dim) for j in range(dim // 2)]
for pos in range(n_pos)
])
out.data[:, 0::2] = torch.as_tensor(np.sin(position_enc))
out.data[:, 1::2] = torch.as_tensor(np.cos(position_enc))
out.detach_()
out.requires_grad = False
def _normalize(tensor, norm_layer):
"""Broadcast layer norm"""
size = tensor.size()
return norm_layer(tensor.view(-1, size[-1])).view(size)
class MultiHeadAttention(nn.Module):
def __init__(self, n_heads, dim, dropout=.0):
super(MultiHeadAttention, self).__init__()
self.n_heads = n_heads
self.dim = dim
self.attn_dropout = nn.Dropout(p=dropout) # --attention-dropout
self.q_lin = nn.Linear(dim, dim)
self.k_lin = nn.Linear(dim, dim)
self.v_lin = nn.Linear(dim, dim)
# TODO: merge for the initialization step
nn.init.xavier_normal_(self.q_lin.weight)
nn.init.xavier_normal_(self.k_lin.weight)
nn.init.xavier_normal_(self.v_lin.weight)
# and set biases to 0
self.out_lin = nn.Linear(dim, dim)
nn.init.xavier_normal_(self.out_lin.weight)
def forward(self, query, key=None, value=None, mask=None):
# Input is [B, query_len, dim]
# Mask is [B, key_len] (selfattn) or [B, key_len, key_len] (enc attn)
batch_size, query_len, dim = query.size()
assert dim == self.dim, \
f'Dimensions do not match: {dim} query vs {self.dim} configured'
assert mask is not None, 'Mask is None, please specify a mask'
n_heads = self.n_heads
dim_per_head = dim // n_heads
scale = math.sqrt(dim_per_head)
def prepare_head(tensor):
# input is [batch_size, seq_len, n_heads * dim_per_head]
# output is [batch_size * n_heads, seq_len, dim_per_head]
bsz, seq_len, _ = tensor.size()
tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head)
tensor = tensor.transpose(1, 2).contiguous().view(
batch_size * n_heads,
seq_len,
dim_per_head
)
return tensor
# q, k, v are the transformed values
if key is None and value is None:
# self attention
key = value = query
elif value is None:
# key and value are the same, but query differs
# self attention
value = key
_, key_len, dim = key.size()
q = prepare_head(self.q_lin(query))
k = prepare_head(self.k_lin(key))
v = prepare_head(self.v_lin(value))
dot_prod = q.div_(scale).bmm(k.transpose(1, 2))
# [B * n_heads, query_len, key_len]
attn_mask = (
(mask == 0)
.view(batch_size, 1, -1, key_len)
.repeat(1, n_heads, 1, 1)
.expand(batch_size, n_heads, query_len, key_len)
.view(batch_size * n_heads, query_len, key_len)
)
assert attn_mask.shape == dot_prod.shape
dot_prod.masked_fill_(attn_mask, neginf(dot_prod.dtype))
attn_weights = F.softmax(dot_prod, dim=-1).type_as(query)
attn_weights = self.attn_dropout(attn_weights) # --attention-dropout
attentioned = attn_weights.bmm(v)
attentioned = (
attentioned.type_as(query)
.view(batch_size, n_heads, query_len, dim_per_head)
.transpose(1, 2).contiguous()
.view(batch_size, query_len, dim)
)
out = self.out_lin(attentioned)
return out
class TransformerFFN(nn.Module):
def __init__(self, dim, dim_hidden, relu_dropout=.0):
super(TransformerFFN, self).__init__()
self.relu_dropout = nn.Dropout(p=relu_dropout)
self.lin1 = nn.Linear(dim, dim_hidden)
self.lin2 = nn.Linear(dim_hidden, dim)
nn.init.xavier_uniform_(self.lin1.weight)
nn.init.xavier_uniform_(self.lin2.weight)
# TODO: initialize biases to 0
def forward(self, x):
x = F.relu(self.lin1(x))
x = self.relu_dropout(x) # --relu-dropout
x = self.lin2(x)
return x
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
n_heads,
embedding_size,
ffn_size,
attention_dropout=0.0,
relu_dropout=0.0,
dropout=0.0,
):
super().__init__()
self.dim = embedding_size
self.ffn_dim = ffn_size
self.attention = MultiHeadAttention(
n_heads, embedding_size,
dropout=attention_dropout, # --attention-dropout
)
self.norm1 = nn.LayerNorm(embedding_size)
self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)
self.norm2 = nn.LayerNorm(embedding_size)
self.dropout = nn.Dropout(p=dropout)
def forward(self, tensor, mask):
tensor = tensor + self.dropout(self.attention(tensor, mask=mask))
tensor = _normalize(tensor, self.norm1)
tensor = tensor + self.dropout(self.ffn(tensor))
tensor = _normalize(tensor, self.norm2)
tensor *= mask.unsqueeze(-1).type_as(tensor)
return tensor
class TransformerEncoder(nn.Module):
"""
Transformer encoder module.
:param int n_heads: the number of multihead attention heads.
:param int n_layers: number of transformer layers.
:param int embedding_size: the embedding sizes. Must be a multiple of n_heads.
:param int ffn_size: the size of the hidden layer in the FFN
:param embedding: an embedding matrix for the bottom layer of the transformer.
If none, one is created for this encoder.
:param float dropout: Dropout used around embeddings and before layer
normalizations. This is used in Vaswani 2017 and works well on
large datasets.
:param float attention_dropout: Dropout performed after the multihead attention
softmax. This is not used in Vaswani 2017.
:param float relu_dropout: Dropout used after the ReLU in the FFN. Not used
in Vaswani 2017, but used in Tensor2Tensor.
:param int padding_idx: Reserved padding index in the embeddings matrix.
:param bool learn_positional_embeddings: If off, sinusoidal embeddings are
used. If on, position embeddings are learned from scratch.
:param bool embeddings_scale: Scale embeddings relative to their dimensionality.
Found useful in fairseq.
:param bool reduction: If true, returns the mean vector for the entire encoding
sequence.
:param int n_positions: Size of the position embeddings matrix.
"""
def __init__(
self,
n_heads,
n_layers,
embedding_size,
ffn_size,
vocabulary_size,
embedding=None,
dropout=0.0,
attention_dropout=0.0,
relu_dropout=0.0,
padding_idx=0,
learn_positional_embeddings=False,
embeddings_scale=False,
reduction=True,
n_positions=1024
):
super(TransformerEncoder, self).__init__()
self.embedding_size = embedding_size
self.ffn_size = ffn_size
self.n_layers = n_layers
self.n_heads = n_heads
self.dim = embedding_size
self.embeddings_scale = embeddings_scale
self.reduction = reduction
self.padding_idx = padding_idx
# this is --dropout, not --relu-dropout or --attention-dropout
self.dropout = nn.Dropout(dropout)
self.out_dim = embedding_size
assert embedding_size % n_heads == 0, \
'Transformer embedding size must be a multiple of n_heads'
# check input formats:
if embedding is not None:
assert (
embedding_size is None or embedding_size == embedding.weight.shape[1]
), "Embedding dim must match the embedding size."
if embedding is not None:
self.embeddings = embedding
else:
assert False
assert padding_idx is not None
self.embeddings = nn.Embedding(
vocabulary_size, embedding_size, padding_idx=padding_idx
)
nn.init.normal_(self.embeddings.weight, 0, embedding_size ** -0.5)
# create the positional embeddings
self.position_embeddings = nn.Embedding(n_positions, embedding_size)
if not learn_positional_embeddings:
create_position_codes(
n_positions, embedding_size, out=self.position_embeddings.weight
)
else:
nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)
# build the model
self.layers = nn.ModuleList()
for _ in range(self.n_layers):
self.layers.append(TransformerEncoderLayer(
n_heads, embedding_size, ffn_size,
attention_dropout=attention_dropout,
relu_dropout=relu_dropout,
dropout=dropout,
))
def forward(self, input):
"""
input is a LongTensor of token ids with shape [batch, seq_len].
A mask of shape [batch, seq_len] is derived from padding_idx, filled with 1
inside the sequence and 0 at padding positions.
"""
mask = input != self.padding_idx
positions = (mask.cumsum(dim=1, dtype=torch.int64) - 1).clamp_(min=0)
tensor = self.embeddings(input)
if self.embeddings_scale:
tensor = tensor * np.sqrt(self.dim)
tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
# --dropout on the embeddings
tensor = self.dropout(tensor)
tensor *= mask.unsqueeze(-1).type_as(tensor)
for i in range(self.n_layers):
tensor = self.layers[i](tensor, mask)
if self.reduction:
divisor = mask.type_as(tensor).sum(dim=1).unsqueeze(-1).clamp(min=1e-7)
output = tensor.sum(dim=1) / divisor
return output
else:
output = tensor
return output, mask
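# Usage sketch (illustrative sizes; an embedding module must be supplied):
#
#     emb = nn.Embedding(vocab_size, 300, padding_idx=0)
#     enc = TransformerEncoder(n_heads=2, n_layers=2, embedding_size=300, ffn_size=300,
#                              vocabulary_size=vocab_size, embedding=emb,
#                              reduction=False, n_positions=1024)
#     tokens = torch.randint(1, vocab_size, (8, 40))   # (batch, seq_len) token ids, 0 = padding
#     output, mask = enc(tokens)                       # (8, 40, 300), (8, 40)
#
# With reduction=True the encoder instead returns the masked mean over time, shape (8, 300).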
class TransformerDecoderLayer(nn.Module):
def __init__(
self,
n_heads,
embedding_size,
ffn_size,
attention_dropout=0.0,
relu_dropout=0.0,
dropout=0.0,
):
super().__init__()
self.dim = embedding_size
self.ffn_dim = ffn_size
self.dropout = nn.Dropout(p=dropout)
self.self_attention = MultiHeadAttention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm1 = nn.LayerNorm(embedding_size)
self.encoder_attention = MultiHeadAttention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm2 = nn.LayerNorm(embedding_size)
self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)
self.norm3 = nn.LayerNorm(embedding_size)
def forward(self, x, encoder_output, encoder_mask):
decoder_mask = self._create_selfattn_mask(x)
# first self attn
residual = x
# don't peak into the future!
x = self.self_attention(query=x, mask=decoder_mask)
x = self.dropout(x) # --dropout
x = x + residual
x = _normalize(x, self.norm1)
residual = x
x = self.encoder_attention(
query=x,
key=encoder_output,
value=encoder_output,
mask=encoder_mask
)
x = self.dropout(x) # --dropout
x = residual + x
x = _normalize(x, self.norm2)
# finally the ffn
residual = x
x = self.ffn(x)
x = self.dropout(x) # --dropout
x = residual + x
x = _normalize(x, self.norm3)
return x
def _create_selfattn_mask(self, x):
# figure out how many timestamps we need
bsz = x.size(0)
time = x.size(1)
# make sure that we don't look into the future
mask = torch.tril(x.new(time, time).fill_(1))
# broadcast across batch
mask = mask.unsqueeze(0).expand(bsz, -1, -1)
return mask
class TransformerDecoder(nn.Module):
"""
Transformer Decoder layer.
:param int n_heads: the number of multihead attention heads.
:param int n_layers: number of transformer layers.
    :param int embedding_size: the embedding size. Must be a multiple of n_heads.
    :param int ffn_size: the size of the hidden layer in the FFN
    :param embedding: an embedding matrix for the bottom layer of the transformer.
        If none, one is created for this decoder.
    :param float dropout: Dropout used around embeddings and before layer
        normalizations. This is used in Vaswani 2017 and works well on
        large datasets.
    :param float attention_dropout: Dropout performed after the multihead attention
        softmax. This is not used in Vaswani 2017.
:param int padding_idx: Reserved padding index in the embeddings matrix.
:param bool learn_positional_embeddings: If off, sinusoidal embeddings are
used. If on, position embeddings are learned from scratch.
:param bool embeddings_scale: Scale embeddings relative to their dimensionality.
Found useful in fairseq.
:param int n_positions: Size of the position embeddings matrix.
"""
def __init__(
self,
n_heads,
n_layers,
embedding_size,
ffn_size,
vocabulary_size,
embedding=None,
dropout=0.0,
attention_dropout=0.0,
relu_dropout=0.0,
embeddings_scale=True,
learn_positional_embeddings=False,
padding_idx=None,
n_positions=1024,
):
super().__init__()
self.embedding_size = embedding_size
self.ffn_size = ffn_size
self.n_layers = n_layers
self.n_heads = n_heads
self.dim = embedding_size
self.embeddings_scale = embeddings_scale
self.dropout = nn.Dropout(p=dropout) # --dropout
self.out_dim = embedding_size
assert embedding_size % n_heads == 0, \
'Transformer embedding size must be a multiple of n_heads'
self.embeddings = embedding
# create the positional embeddings
self.position_embeddings = nn.Embedding(n_positions, embedding_size)
if not learn_positional_embeddings:
create_position_codes(
n_positions, embedding_size, out=self.position_embeddings.weight
)
else:
nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)
# build the model
self.layers = nn.ModuleList()
for _ in range(self.n_layers):
self.layers.append(TransformerDecoderLayer(
n_heads, embedding_size, ffn_size,
attention_dropout=attention_dropout,
relu_dropout=relu_dropout,
dropout=dropout,
))
def forward(self, input, encoder_state, incr_state=None):
encoder_output, encoder_mask = encoder_state
seq_len = input.shape[1]
positions = input.new_empty(seq_len).long()
        positions = torch.arange(seq_len, out=positions).unsqueeze(0)  # (1, seq_len), broadcast over the batch
tensor = self.embeddings(input)
if self.embeddings_scale:
tensor = tensor * np.sqrt(self.dim)
tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
tensor = self.dropout(tensor) # --dropout
for layer in self.layers:
tensor = layer(tensor, encoder_output, encoder_mask)
return tensor, None
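
As a quick illustration of the index bookkeeping above (a minimal sketch assuming only torch; shapes follow the forward docstrings), the padding mask, position ids and causal mask can be reproduced on toy inputs:

import torch

pad_idx = 0
inp = torch.tensor([[5, 7, 9, pad_idx, pad_idx]])      # (batch, seq_len) token indices
mask = inp != pad_idx                                  # 1 inside the sequence, 0 on padding
positions = (mask.cumsum(dim=1, dtype=torch.int64) - 1).clamp_(min=0)
# positions -> [[0, 1, 2, 2, 2]]; padded slots reuse the last index and are zeroed out afterwards

# causal mask as built by TransformerDecoderLayer._create_selfattn_mask
bsz, time = 2, 4
x = torch.zeros(bsz, time, 8)
causal = torch.tril(x.new(time, time).fill_(1)).unsqueeze(0).expand(bsz, -1, -1)
# causal[0] is lower-triangular: step t may attend to steps 0..t only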

View File

@@ -0,0 +1 @@
from .quick_start import run_crslab

View File

@@ -0,0 +1,82 @@
# -*- encoding: utf-8 -*-
# @Time : 2021/1/8
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
# UPDATE
# @Time : 2021/1/9
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com
from crslab.config import Config
from crslab.data import get_dataset, get_dataloader
from crslab.system import get_system
def run_crslab(config, save_data=False, restore_data=False, save_system=False, restore_system=False,
interact=False, debug=False):
"""A fast running api, which includes the complete process of training and testing models on specified datasets.
Args:
config (Config or str): an instance of ``Config`` or path to the config file,
which should be in ``yaml`` format. You can use default config provided in the `Github repo`_,
or write it by yourself.
save_data (bool): whether to save data. Defaults to False.
restore_data (bool): whether to restore data. Defaults to False.
save_system (bool): whether to save system. Defaults to False.
restore_system (bool): whether to restore system. Defaults to False.
interact (bool): whether to interact with the system. Defaults to False.
debug (bool): whether to debug the system. Defaults to False.
.. _Github repo:
https://github.com/RUCAIBox/CRSLab
"""
# dataset & dataloader
if isinstance(config['tokenize'], str):
CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data)
side_data = CRS_dataset.side_data
vocab = CRS_dataset.vocab
train_dataloader = get_dataloader(config, CRS_dataset.train_data, vocab)
valid_dataloader = get_dataloader(config, CRS_dataset.valid_data, vocab)
test_dataloader = get_dataloader(config, CRS_dataset.test_data, vocab)
else:
tokenized_dataset = {}
train_dataloader = {}
valid_dataloader = {}
test_dataloader = {}
vocab = {}
side_data = {}
for task, tokenize in config['tokenize'].items():
if tokenize in tokenized_dataset:
dataset = tokenized_dataset[tokenize]
else:
dataset = get_dataset(config, tokenize, restore_data, save_data)
tokenized_dataset[tokenize] = dataset
train_data = dataset.train_data
valid_data = dataset.valid_data
test_data = dataset.test_data
train_review = dataset.train_review
valid_review = dataset.valid_review
test_review = dataset.test_review
side_data[task] = dataset.side_data
vocab[task] = dataset.vocab
review_id2entity = dataset.review_id2entity
train_dataloader[task] = get_dataloader(config, train_data, train_review, vocab[task], review_id2entity)
valid_dataloader[task] = get_dataloader(config, valid_data, valid_review, vocab[task], review_id2entity)
test_dataloader[task] = get_dataloader(config, test_data, test_review, vocab[task], review_id2entity)
# system
CRS = get_system(config, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system,
interact, debug)
if interact:
CRS.interact()
else:
CRS.fit()
if save_system:
CRS.save_model()
return
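
A minimal driver sketch for this entry point, mirroring the positional Config call and the argparse defaults in run_crslab.py further down (whether Config provides defaults for the trailing arguments is not shown here, so all six are passed explicitly):

from crslab.config import Config
from crslab.quick_start import run_crslab

# Config(config_path, gpu, debug, seed, pretrain, pretrain_epoch), as in run_crslab.py
config = Config('config/crs/hycorec/hredial.yaml', '-1', False, 2020, False, 9999)
run_crslab(config, save_data=False, restore_data=False, save_system=True,
           restore_system=False, interact=False, debug=False)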

View File

@@ -0,0 +1,36 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24, 2020/12/29
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
# @Time : 2021/10/6
# @Author : Zhipeng Zhao
# @email : oran_official@outlook.com
from loguru import logger
from .hycorec import HyCoRecSystem
system_register_table = {
'HyCoRec': HyCoRecSystem,
}
def get_system(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,
interact=False, debug=False):
"""
    Return the system instance registered under opt['model_name'].
"""
model_name = opt['model_name']
if model_name in system_register_table:
system = system_register_table[model_name](opt, train_dataloader, valid_dataloader, test_dataloader, vocab,
side_data, restore_system, interact, debug)
logger.info(f'[Build system {model_name}]')
return system
else:
raise NotImplementedError('The system with model [{}] in dataset [{}] has not been implemented'.
format(model_name, opt['dataset']))
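
The dispatch above reduces to a plain dictionary lookup keyed by opt['model_name']; sketched without the dataloader plumbing:

opt = {'model_name': 'HyCoRec', 'dataset': 'ReDial'}
system_cls = system_register_table[opt['model_name']]   # -> HyCoRecSystem
# any unregistered name falls through to the NotImplementedError branch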

View File

@@ -0,0 +1,290 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24, 2021/1/9
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
# UPDATE:
# @Time : 2021/11/5
# @Author : Zhipeng Zhao
# @Email : oran_official@outlook.com
import os
from abc import ABC, abstractmethod
import numpy as np
import random
import torch
from loguru import logger
from torch import optim
from transformers import Adafactor
from crslab.config import SAVE_PATH
from crslab.evaluator import get_evaluator
from crslab.evaluator.metrics.base import AverageMetric
from crslab.model import get_model
from crslab.system.utils import lr_scheduler
from crslab.system.utils.functions import compute_grad_norm
optim_class = {}
optim_class.update({k: v for k, v in optim.__dict__.items() if not k.startswith('__') and k[0].isupper()})
optim_class.update({'AdamW': optim.AdamW, 'Adafactor': Adafactor})
lr_scheduler_class = {k: v for k, v in lr_scheduler.__dict__.items() if not k.startswith('__') and k[0].isupper()}
transformers_tokenizer = ('bert', 'gpt2')
class BaseSystem(ABC):
"""Base class for all system"""
def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,
interact=False, debug=False):
"""
Args:
opt (dict): Indicating the hyper parameters.
train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.
valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.
test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.
vocab (dict): Indicating the vocabulary.
side_data (dict): Indicating the side data.
            restore_system (bool, optional): Indicating if we restore the system from a saved checkpoint. Defaults to False.
interact (bool, optional): Indicating if we interact with system. Defaults to False.
debug (bool, optional): Indicating if we train in debug mode. Defaults to False.
"""
self.opt = opt
if opt["gpu"] == [-1]:
self.device = torch.device('cpu')
elif len(opt["gpu"]) == 1:
self.device = torch.device('cuda')
else:
self.device = torch.device('cpu')
# seed
if 'seed' in opt:
seed = int(opt['seed'])
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
logger.info(f'[Set seed] {seed}')
# data
if debug:
self.train_dataloader = valid_dataloader
self.valid_dataloader = valid_dataloader
self.test_dataloader = test_dataloader
else:
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.test_dataloader = test_dataloader
self.vocab = vocab
self.side_data = side_data
# model
if 'model' in opt:
self.model = get_model(opt, opt['model'], self.device, vocab, side_data).to(self.device)
else:
if 'rec_model' in opt:
self.rec_model = get_model(opt, opt['rec_model'], self.device, vocab['rec'], side_data['rec']).to(
self.device)
if 'conv_model' in opt:
self.conv_model = get_model(opt, opt['conv_model'], self.device, vocab['conv'], side_data['conv']).to(
self.device)
if 'policy_model' in opt:
self.policy_model = get_model(opt, opt['policy_model'], self.device, vocab['policy'],
side_data['policy']).to(self.device)
model_file_name = opt.get('model_file', f'{opt["model_name"]}.pth')
self.model_file = os.path.join(SAVE_PATH, model_file_name)
if restore_system:
self.restore_model()
if not interact:
self.evaluator = get_evaluator('standard', opt['dataset'], opt['rankfile'])
def init_optim(self, opt, parameters):
self.optim_opt = opt
parameters = list(parameters)
if isinstance(parameters[0], dict):
for i, d in enumerate(parameters):
parameters[i]['params'] = list(d['params'])
        # gradient accumulation
self.update_freq = opt.get('update_freq', 1)
self._number_grad_accum = 0
self.gradient_clip = opt.get('gradient_clip', -1)
self.build_optimizer(parameters)
self.build_lr_scheduler()
if isinstance(parameters[0], dict):
self.parameters = []
for d in parameters:
self.parameters.extend(d['params'])
else:
self.parameters = parameters
# early stop
self.need_early_stop = self.optim_opt.get('early_stop', False)
if self.need_early_stop:
logger.debug('[Enable early stop]')
self.reset_early_stop_state()
def build_optimizer(self, parameters):
optimizer_opt = self.optim_opt['optimizer']
optimizer = optimizer_opt.pop('name')
self.optimizer = optim_class[optimizer](parameters, **optimizer_opt)
logger.info(f"[Build optimizer: {optimizer}]")
def build_lr_scheduler(self):
"""
        Create the learning rate scheduler and assign it to self.scheduler.
        The scheduler is stepped per optimizer update via train_step() and
        adjusted after each validation epoch via adjust_lr().
"""
if self.optim_opt.get('lr_scheduler', None):
lr_scheduler_opt = self.optim_opt['lr_scheduler']
lr_scheduler = lr_scheduler_opt.pop('name')
self.scheduler = lr_scheduler_class[lr_scheduler](self.optimizer, **lr_scheduler_opt)
logger.info(f"[Build scheduler {lr_scheduler}]")
def reset_early_stop_state(self):
self.best_valid = None
self.drop_cnt = 0
self.impatience = self.optim_opt.get('impatience', 3)
if self.optim_opt['stop_mode'] == 'max':
self.stop_mode = 1
elif self.optim_opt['stop_mode'] == 'min':
self.stop_mode = -1
else:
            raise ValueError(f"stop_mode must be 'max' or 'min', got {self.optim_opt['stop_mode']}")
logger.debug('[Reset early stop state]')
@abstractmethod
def fit(self):
"""fit the whole system"""
pass
@abstractmethod
def step(self, batch, stage, mode):
"""calculate loss and prediction for batch data under certrain stage and mode
Args:
batch (dict or tuple): batch data
stage (str): recommendation/policy/conversation etc.
mode (str): train/valid/test
"""
pass
def backward(self, loss):
"""empty grad, backward loss and update params
Args:
loss (torch.Tensor):
"""
self._zero_grad()
if self.update_freq > 1:
self._number_grad_accum = (self._number_grad_accum + 1) % self.update_freq
loss /= self.update_freq
loss.backward(loss.clone().detach())
self._update_params()
def _zero_grad(self):
if self._number_grad_accum != 0:
# if we're accumulating gradients, don't actually zero things out yet.
return
self.optimizer.zero_grad()
def _update_params(self):
if self.update_freq > 1:
            # we're doing gradient accumulation, so we only want to step
            # every N updates
# self._number_grad_accum is updated in backward function
if self._number_grad_accum != 0:
return
if self.gradient_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.parameters, self.gradient_clip
)
self.evaluator.optim_metrics.add('grad norm', AverageMetric(grad_norm))
self.evaluator.optim_metrics.add(
'grad clip ratio',
AverageMetric(float(grad_norm > self.gradient_clip)),
)
else:
grad_norm = compute_grad_norm(self.parameters)
self.evaluator.optim_metrics.add('grad norm', AverageMetric(grad_norm))
self.optimizer.step()
if hasattr(self, 'scheduler'):
self.scheduler.train_step()
def adjust_lr(self, metric=None):
"""adjust learning rate w/o metric by scheduler
Args:
metric (optional): Defaults to None.
"""
if not hasattr(self, 'scheduler') or self.scheduler is None:
return
self.scheduler.valid_step(metric)
logger.debug('[Adjust learning rate after valid epoch]')
def early_stop(self, metric):
if not self.need_early_stop:
return False
if self.best_valid is None or metric * self.stop_mode > self.best_valid * self.stop_mode:
self.best_valid = metric
self.drop_cnt = 0
logger.info('[Get new best model]')
return False
else:
self.drop_cnt += 1
if self.drop_cnt >= self.impatience:
logger.info('[Early stop]')
return True
def save_model(self):
r"""Store the model parameters."""
state = {}
if hasattr(self, 'model'):
state['model_state_dict'] = self.model.state_dict()
if hasattr(self, 'rec_model'):
state['rec_state_dict'] = self.rec_model.state_dict()
if hasattr(self, 'conv_model'):
state['conv_state_dict'] = self.conv_model.state_dict()
if hasattr(self, 'policy_model'):
state['policy_state_dict'] = self.policy_model.state_dict()
os.makedirs(SAVE_PATH, exist_ok=True)
torch.save(state, self.model_file)
logger.info(f'[Save model into {self.model_file}]')
def restore_model(self):
r"""Store the model parameters."""
if not os.path.exists(self.model_file):
raise ValueError(f'Saved model [{self.model_file}] does not exist')
checkpoint = torch.load(self.model_file, map_location=self.device)
if hasattr(self, 'model'):
self.model.load_state_dict(checkpoint['model_state_dict'])
if hasattr(self, 'rec_model'):
self.rec_model.load_state_dict(checkpoint['rec_state_dict'])
if hasattr(self, 'conv_model'):
self.conv_model.load_state_dict(checkpoint['conv_state_dict'])
if hasattr(self, 'policy_model'):
self.policy_model.load_state_dict(checkpoint['policy_state_dict'])
logger.info(f'[Restore model from {self.model_file}]')
@abstractmethod
def interact(self):
pass
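
The backward/_zero_grad/_update_params trio above implements gradient accumulation: gradients are zeroed only at the start of an accumulation window, the loss is scaled by 1/update_freq, and the optimizer steps (with optional clipping) only at the end of the window. A standalone sketch of the same protocol with plain torch (names here are illustrative):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
update_freq = 2                       # accumulate gradients over 2 mini-batches
number_grad_accum = 0

for _ in range(4):
    batch = torch.randn(8, 4)
    loss = model(batch).pow(2).mean() / update_freq   # scale so accumulated grads match one large batch
    if number_grad_accum == 0:
        optimizer.zero_grad()                          # zero only at the window start
    loss.backward()
    number_grad_accum = (number_grad_accum + 1) % update_freq
    if number_grad_accum == 0:                         # step only at the window end
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()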

View File

@@ -0,0 +1,176 @@
# -*- encoding: utf-8 -*-
# @Time : 2021/5/26
# @Author : Chenzhan Shang
# @email : czshang@outlook.com
import os
import json
from time import perf_counter
import torch
import pickle as pkl
from loguru import logger
from crslab.evaluator.metrics.base import AverageMetric
from crslab.evaluator.metrics.gen import PPLMetric
from crslab.system.base import BaseSystem
from crslab.system.utils.functions import ind2txt
class HyCoRecSystem(BaseSystem):
"""This is the system for KBRD model"""
def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,
interact=False, debug=False):
"""
Args:
opt (dict): Indicating the hyper parameters.
train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.
valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.
test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.
vocab (dict): Indicating the vocabulary.
side_data (dict): Indicating the side data.
            restore_system (bool, optional): Indicating if we restore the system from a saved checkpoint. Defaults to False.
interact (bool, optional): Indicating if we interact with system. Defaults to False.
debug (bool, optional): Indicating if we train in debug mode. Defaults to False.
"""
super(HyCoRecSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data,
restore_system, interact, debug)
self.ind2tok = vocab['ind2tok']
self.end_token_idx = vocab['tok2ind']['__end__']
self.item_ids = side_data['item_entity_ids']
self.rec_optim_opt = opt['rec']
self.conv_optim_opt = opt['conv']
self.rec_epoch = self.rec_optim_opt['epoch']
self.conv_epoch = self.conv_optim_opt['epoch']
self.rec_batch_size = self.rec_optim_opt['batch_size']
self.conv_batch_size = self.conv_optim_opt['batch_size']
def rec_evaluate(self, rec_predict, item_label):
rec_predict = rec_predict.cpu()
rec_predict = rec_predict[:, self.item_ids]
_, rec_ranks = torch.topk(rec_predict, 50, dim=-1)
rec_ranks = rec_ranks.tolist()
item_label = item_label.tolist()
# start = perf_counter()
for rec_rank, label in zip(rec_ranks, item_label):
label = self.item_ids.index(label)
self.evaluator.rec_evaluate(rec_rank, label)
# print(f"{perf_counter() - start}")
def conv_evaluate(self, prediction, response, batch_user_id=None, batch_conv_id=None):
prediction = prediction.tolist()
response = response.tolist()
if batch_user_id is None:
for p, r in zip(prediction, response):
p_str = ind2txt(p, self.ind2tok, self.end_token_idx)
r_str = ind2txt(r, self.ind2tok, self.end_token_idx)
self.evaluator.gen_evaluate(p_str, [r_str], p)
else:
for p, r, uid, cid in zip(prediction, response, batch_user_id, batch_conv_id):
p_str = ind2txt(p, self.ind2tok, self.end_token_idx)
r_str = ind2txt(r, self.ind2tok, self.end_token_idx)
self.evaluator.gen_evaluate(p_str, [r_str], p)
def step(self, batch, stage, mode):
assert stage in ('rec', 'conv')
assert mode in ('train', 'valid', 'test')
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device)
if stage == 'rec':
rec_loss, rec_scores = self.model.forward(batch, mode, stage)
rec_loss = rec_loss.sum()
if mode == 'train':
self.backward(rec_loss)
else:
self.rec_evaluate(rec_scores, batch['item'])
rec_loss = rec_loss.item()
self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss))
else:
if mode != 'test':
gen_loss, preds = self.model.forward(batch, mode, stage)
if mode == 'train':
self.backward(gen_loss)
else:
self.conv_evaluate(preds, batch['response'])
gen_loss = gen_loss.item()
self.evaluator.optim_metrics.add('gen_loss', AverageMetric(gen_loss))
self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
else:
preds = self.model.forward(batch, mode, stage)
self.conv_evaluate(preds, batch['response'], batch.get('user_id', None), batch['conv_id'])
def train_recommender(self):
self.init_optim(self.rec_optim_opt, self.model.parameters())
for epoch in range(self.rec_epoch):
self.evaluator.reset_metrics()
logger.info(f'[Recommendation epoch {str(epoch)}]')
logger.info('[Train]')
for batch in self.train_dataloader.get_rec_data(self.rec_batch_size):
self.step(batch, stage='rec', mode='train')
self.evaluator.report(epoch=epoch, mode='train')
# val
logger.info('[Valid]')
with torch.no_grad():
self.evaluator.reset_metrics()
for batch in self.valid_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):
self.step(batch, stage='rec', mode='valid')
self.evaluator.report(epoch=epoch, mode='valid')
# early stop
metric = self.evaluator.optim_metrics['rec_loss']
if self.early_stop(metric):
break
# test
logger.info('[Test]')
with torch.no_grad():
self.evaluator.reset_metrics()
for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):
self.step(batch, stage='rec', mode='test')
self.evaluator.report(mode='test')
def train_conversation(self):
        if os.environ.get("CUDA_VISIBLE_DEVICES", "-1") == '-1':
self.model.freeze_parameters()
else:
self.model.module.freeze_parameters()
self.init_optim(self.conv_optim_opt, self.model.parameters())
for epoch in range(self.conv_epoch):
self.evaluator.reset_metrics()
logger.info(f'[Conversation epoch {str(epoch)}]')
logger.info('[Train]')
for batch in self.train_dataloader.get_conv_data(batch_size=self.conv_batch_size):
self.step(batch, stage='conv', mode='train')
self.evaluator.report(epoch=epoch, mode='train')
# val
logger.info('[Valid]')
with torch.no_grad():
self.evaluator.reset_metrics()
for batch in self.valid_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):
self.step(batch, stage='conv', mode='valid')
self.evaluator.report(epoch=epoch, mode='valid')
# early stop
metric = self.evaluator.optim_metrics['gen_loss']
if self.early_stop(metric):
break
# test
logger.info('[Test]')
with torch.no_grad():
self.evaluator.reset_metrics()
for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):
self.step(batch, stage='conv', mode='test')
self.evaluator.report(mode='test')
def fit(self):
self.train_recommender()
self.train_conversation()
def interact(self):
pass
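
Since rec_evaluate above scores only the item-entity columns, the ground-truth entity id has to be re-mapped to its position inside item_ids before ranking metrics are computed; a toy numeric sketch:

import torch

item_ids = [10, 42, 77]                        # entity ids that are recommendable items
rec_predict = torch.zeros(1, 100)              # model scores over all entities
rec_predict[0, 42] = 5.0                       # the model prefers entity 42
scores = rec_predict[:, item_ids]              # keep only the item columns -> shape (1, 3)
_, rec_ranks = torch.topk(scores, k=3, dim=-1) # ranks are positions within item_ids
label = item_ids.index(42)                     # ground-truth entity mapped to position 1
assert rec_ranks[0, 0].item() == label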

View File

View File

@@ -0,0 +1,66 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24, 2020/12/18
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
# UPDATE:
# @Time : 2021/10/05
# @Author : Zhipeng Zhao
# @email : oran_official@outlook.com
import torch
def compute_grad_norm(parameters, norm_type=2.0):
"""
Compute norm over gradients of model parameters.
:param parameters:
the model parameters for gradient norm calculation. Iterable of
Tensors or single Tensor
:param norm_type:
type of p-norm to use
:returns:
the computed gradient norm
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p is not None and p.grad is not None]
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
return total_norm ** (1.0 / norm_type)
def ind2txt(inds, ind2tok, end_token_idx=None, unk_token='unk'):
sentence = []
for ind in inds:
if isinstance(ind, torch.Tensor):
ind = ind.item()
if end_token_idx and ind == end_token_idx:
break
sentence.append(ind2tok.get(ind, unk_token))
return ' '.join(sentence)
def ind2txt_with_slots(inds, slots, ind2tok, end_token_idx=None, unk_token='unk', slot_token='[ITEM]'):
sentence = []
for ind in inds:
if isinstance(ind, torch.Tensor):
ind = ind.item()
if end_token_idx and ind == end_token_idx:
break
token = ind2tok.get(ind, unk_token)
if token == slot_token:
token = slots[0]
slots = slots[1:]
sentence.append(token)
return ' '.join(sentence)
def ind2slot(inds, ind2slot):
    return [ind2slot[ind] for ind in inds]
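
Quick sanity checks for the helpers above (using the functions defined in this file; expected outputs shown in the comments):

import torch

ind2tok = {1: 'i', 2: 'like', 3: 'movies', 9: '__end__'}
print(ind2txt([1, 2, 3, 9, 2], ind2tok, end_token_idx=9))   # -> "i like movies"
print(ind2txt([1, 4, 3], ind2tok))                          # unknown id 4 -> "i unk movies"

p = torch.nn.Parameter(torch.tensor([3.0, 4.0]))
p.grad = torch.tensor([3.0, 4.0])
print(compute_grad_norm([p]))                               # -> 5.0, the L2 norm of the gradients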

View File

@@ -0,0 +1,316 @@
# @Time : 2020/12/1
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
from abc import abstractmethod, ABC
# UPDATE:
# @Time : 2020/12/14
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
import math
import numpy as np
import torch
from loguru import logger
from torch import optim
class LRScheduler(ABC):
"""
Class for LR Schedulers.
Includes some basic functionality by default - setting up the warmup
scheduler, passing the correct number of steps to train_step, loading and
saving states.
Subclasses must implement abstract methods train_step() and valid_step().
Schedulers should be initialized with lr_scheduler_factory().
__init__() should not be called directly.
"""
def __init__(self, optimizer, warmup_steps: int = 0):
"""
Initialize warmup scheduler. Specific main schedulers should be initialized in
        the subclasses. Do not invoke this method directly.
:param optimizer optimizer:
Optimizer being used for training. May be wrapped in
fp16_optimizer_wrapper depending on whether fp16 is used.
:param int warmup_steps:
Number of training step updates warmup scheduler should take.
"""
self._number_training_updates = 0
self.warmup_steps = warmup_steps
self._init_warmup_scheduler(optimizer)
def _warmup_lr(self, step):
"""
Return lr multiplier (on initial lr) for warmup scheduler.
"""
return float(step) / float(max(1, self.warmup_steps))
def _init_warmup_scheduler(self, optimizer):
if self.warmup_steps > 0:
self.warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._warmup_lr)
else:
self.warmup_scheduler = None
def _is_lr_warming_up(self):
"""
Check if we're warming up the learning rate.
"""
return (
hasattr(self, 'warmup_scheduler')
and self.warmup_scheduler is not None
and self._number_training_updates <= self.warmup_steps
)
def train_step(self):
"""
Use the number of train steps to adjust the warmup scheduler or the main
scheduler, depending on where in training we are.
Override this method to override the behavior for training schedulers.
"""
self._number_training_updates += 1
if self._is_lr_warming_up():
self.warmup_scheduler.step()
else:
self.train_adjust()
def valid_step(self, metric=None):
if self._is_lr_warming_up():
# we're not done warming up, so don't start using validation
# metrics to adjust schedule
return
self.valid_adjust(metric)
@abstractmethod
def train_adjust(self):
"""
Use the number of train steps to decide when to adjust LR schedule.
Override this method to override the behavior for training schedulers.
"""
pass
@abstractmethod
def valid_adjust(self, metric):
"""
Use the metrics to decide when to adjust LR schedule.
        This uses the loss as the validation metric if present; if not, this
        function does nothing. Note that the model must report a loss for
        this to work.
Override this method to override the behavior for validation schedulers.
"""
pass
class ReduceLROnPlateau(LRScheduler):
"""
Scheduler that decays by a multiplicative rate when valid loss plateaus.
"""
def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001,
threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, warmup_steps=0):
super(ReduceLROnPlateau, self).__init__(optimizer, warmup_steps)
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode=mode, factor=factor,
patience=patience, threshold=threshold,
threshold_mode=threshold_mode, cooldown=cooldown,
min_lr=min_lr, eps=eps)
def train_adjust(self):
pass
def valid_adjust(self, metric):
self.scheduler.step(metric)
class StepLR(LRScheduler):
"""
Scheduler that decays by a fixed multiplicative rate at each valid step.
"""
def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, warmup_steps=0):
super(StepLR, self).__init__(optimizer, warmup_steps)
self.scheduler = optim.lr_scheduler.StepLR(optimizer, step_size, gamma, last_epoch)
def train_adjust(self):
pass
def valid_adjust(self, metric=None):
self.scheduler.step()
class ConstantLR(LRScheduler):
def __init__(self, optimizer, warmup_steps=0):
super(ConstantLR, self).__init__(optimizer, warmup_steps)
def train_adjust(self):
pass
def valid_adjust(self, metric):
pass
class InvSqrtLR(LRScheduler):
"""
Scheduler that decays at an inverse square root rate.
"""
def __init__(self, optimizer, invsqrt_lr_decay_gamma=-1, last_epoch=-1, warmup_steps=0):
"""
invsqrt_lr_decay_gamma determines the cycle length of the inverse square root
scheduler.
When steps taken == invsqrt_lr_decay_gamma, the lr multiplier is 1
"""
super(InvSqrtLR, self).__init__(optimizer, warmup_steps)
self.invsqrt_lr_decay_gamma = invsqrt_lr_decay_gamma
if invsqrt_lr_decay_gamma <= 0:
logger.warning(
'--lr-scheduler invsqrt requires a value for '
'--invsqrt-lr-decay-gamma. Defaulting to set gamma to '
'--warmup-updates value for backwards compatibility.'
)
self.invsqrt_lr_decay_gamma = self.warmup_steps
self.decay_factor = np.sqrt(max(1, self.invsqrt_lr_decay_gamma))
self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._invsqrt_lr, last_epoch)
def _invsqrt_lr(self, step):
return self.decay_factor / np.sqrt(max(1, self.invsqrt_lr_decay_gamma + step))
def train_adjust(self):
self.scheduler.step()
def valid_adjust(self, metric):
# this is a training step lr scheduler, nothing to adjust in validation
pass
class CosineAnnealingLR(LRScheduler):
"""
Scheduler that decays by a cosine function.
"""
def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, warmup_steps=0):
"""
training_steps determines the cycle length of the cosine annealing.
It indicates the number of steps from 1.0 multiplier to 0.0, which corresponds
to going from cos(0) to cos(pi)
"""
super(CosineAnnealingLR, self).__init__(optimizer, warmup_steps)
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min, last_epoch)
def train_adjust(self):
self.scheduler.step()
def valid_adjust(self, metric):
pass
class CosineAnnealingWarmRestartsLR(LRScheduler):
def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, warmup_steps=0):
super(CosineAnnealingWarmRestartsLR, self).__init__(optimizer, warmup_steps)
self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult, eta_min, last_epoch)
def train_adjust(self):
self.scheduler.step()
def valid_adjust(self, metric):
pass
class TransformersLinearLR(LRScheduler):
"""
Scheduler that decays linearly.
"""
def __init__(self, optimizer, training_steps, warmup_steps=0):
"""
training_steps determines the cycle length of the linear annealing.
It indicates the number of steps from 1.0 multiplier to 0.0
"""
super(TransformersLinearLR, self).__init__(optimizer, warmup_steps)
self.training_steps = training_steps - warmup_steps
self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._linear_lr)
def _linear_lr(self, step):
return max(0.0, float(self.training_steps - step) / float(max(1, self.training_steps)))
def train_adjust(self):
self.scheduler.step()
def valid_adjust(self, metric):
pass
class TransformersCosineLR(LRScheduler):
def __init__(self, optimizer, training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1,
warmup_steps: int = 0):
super(TransformersCosineLR, self).__init__(optimizer, warmup_steps)
self.training_steps = training_steps - warmup_steps
self.num_cycles = num_cycles
self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._cosine_lr, last_epoch)
def _cosine_lr(self, step):
progress = float(step) / float(max(1, self.training_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
def train_adjust(self):
self.scheduler.step()
def valid_adjust(self, metric):
pass
class TransformersCosineWithHardRestartsLR(LRScheduler):
def __init__(self, optimizer, training_steps: int, num_cycles: int = 1, last_epoch: int = -1,
warmup_steps: int = 0):
super(TransformersCosineWithHardRestartsLR, self).__init__(optimizer, warmup_steps)
self.training_steps = training_steps - warmup_steps
self.num_cycles = num_cycles
self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._cosine_with_hard_restarts_lr, last_epoch)
def _cosine_with_hard_restarts_lr(self, step):
progress = float(step) / float(max(1, self.training_steps))
if progress >= 1.0:
return 0.0
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(self.num_cycles) * progress) % 1.0))))
def train_adjust(self):
self.scheduler.step()
def valid_adjust(self, metric):
pass
class TransformersPolynomialDecayLR(LRScheduler):
def __init__(self, optimizer, training_steps, lr_end=1e-7, power=1.0, last_epoch=-1, warmup_steps=0):
super(TransformersPolynomialDecayLR, self).__init__(optimizer, warmup_steps)
self.training_steps = training_steps - warmup_steps
self.lr_init = optimizer.defaults["lr"]
self.lr_end = lr_end
        assert self.lr_init > lr_end, f"lr_end ({lr_end}) must be smaller than the initial lr ({self.lr_init})"
self.power = power
self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._polynomial_decay_lr, last_epoch)
def _polynomial_decay_lr(self, step):
if step > self.training_steps:
return self.lr_end / self.lr_init # as LambdaLR multiplies by lr_init
else:
lr_range = self.lr_init - self.lr_end
decay_steps = self.training_steps
pct_remaining = 1 - step / decay_steps
decay = lr_range * pct_remaining ** self.power + self.lr_end
return decay / self.lr_init # as LambdaLR multiplies by lr_init
def train_adjust(self):
self.scheduler.step()
def valid_adjust(self, metric):
pass
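
The warmup wrapper above multiplies the base lr by step/warmup_steps on every train_step() until warmup_steps updates have passed, after which the wrapped torch scheduler takes over; a small sketch with the ReduceLROnPlateau wrapper defined above on a dummy parameter (illustrative numbers):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=0, warmup_steps=2)

scheduler.train_step()        # warmup step 1: lr = 1.0 * 1/2
scheduler.train_step()        # warmup step 2: lr = 1.0 * 2/2
scheduler.train_step()        # warmup finished; the plateau scheduler is now in charge
scheduler.valid_step(1.0)     # best validation loss so far
scheduler.valid_step(2.0)     # loss got worse and patience=0, so the lr is halved
print(optimizer.param_groups[0]['lr'])   # -> 0.5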

45
HyCoRec/run_crslab.py Normal file
View File

@@ -0,0 +1,45 @@
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24, 2021/1/9
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
import argparse
import warnings
from crslab.config import Config
warnings.filterwarnings('ignore')
if __name__ == '__main__':
# parse args
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str,
default='config/crs/hycorec/hredial.yaml', help='config file(yaml) path')
parser.add_argument('-g', '--gpu', type=str, default='-1',
help='specify GPU id(s) to use, we now support multiple GPUs. Defaults to CPU(-1).')
parser.add_argument('-sd', '--save_data', action='store_true',
help='save processed dataset')
parser.add_argument('-rd', '--restore_data', action='store_true',
help='restore processed dataset')
parser.add_argument('-ss', '--save_system', action='store_true',
help='save trained system')
parser.add_argument('-rs', '--restore_system', action='store_true',
help='restore trained system')
parser.add_argument('-d', '--debug', action='store_true',
help='use valid dataset to debug your system')
parser.add_argument('-i', '--interact', action='store_true',
help='interact with your system instead of training')
parser.add_argument('-s', '--seed', type=int, default=2020)
parser.add_argument('-p', '--pretrain', action='store_true', help='use pretrain weights')
parser.add_argument('-e', '--pretrain_epoch', type=int, default=9999, help='pretrain epoch')
args, _ = parser.parse_known_args()
config = Config(args.config, args.gpu, args.debug, args.seed, args.pretrain, args.pretrain_epoch)
from crslab.quick_start import run_crslab
run_crslab(config, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact,
args.debug)
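
For reference, a typical invocation built from the flags above (the config path is the parser's default; any yaml under config/ can be substituted, and -g 0 selects the first GPU):

python run_crslab.py -c config/crs/hycorec/hredial.yaml -g 0 -sd -ss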

0
README.md Normal file
View File

23
pyproject.toml Normal file
View File

@@ -0,0 +1,23 @@
[project]
name = "HyCoRec"
version = "0.1.0"
description = "The implementation of HyCoRec: Hypergraph-Enhanced Multi-Preference Learning for Alleviating Matthew Effect in Conversational Recommendation (ACL 2024)"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"argparse>=1.4.0",
"dgl>=2.2.1",
"loguru>=0.7.3",
"nltk>=3.9.1",
"pyyaml>=6.0.2",
"requests>=2.32.4",
"scikit-learn>=1.7.1",
"torch>=2.8.0",
"torch-geometric>=2.6.1",
"tqdm>=4.67.1",
"transformers>=4.55.0",
]
[[tool.uv.index]]
url="https://pypi.tuna.tsinghua.edu.cn/simple"
default=true