init HyCoRec
This commit is contained in:
50
HyCoRec/config/crs/hycorec/durecdial.yaml
Normal file
50
HyCoRec/config/crs/hycorec/durecdial.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
# dataset
|
||||
dataset: DuRecDial
|
||||
tokenize: jieba
|
||||
# dataloader
|
||||
related_truncate: 1024
|
||||
context_truncate: 256
|
||||
response_truncate: 30
|
||||
scale: 1
|
||||
# model
|
||||
model: HyCoRec
|
||||
token_emb_dim: 300
|
||||
kg_emb_dim: 128
|
||||
num_bases: 8
|
||||
n_heads: 2
|
||||
n_layers: 2
|
||||
ffn_size: 300
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
relu_dropout: 0.1
|
||||
learn_positional_embeddings: false
|
||||
embeddings_scale: true
|
||||
reduction: false
|
||||
n_positions: 1024
|
||||
user_proj_dim: 512
|
||||
# HyCoRec-CHANGE
|
||||
mha_n_heads: 4
|
||||
pooling: Attn
|
||||
extension_strategy: Adaptive
|
||||
# optim
|
||||
rec:
|
||||
epoch: 1
|
||||
batch_size: 64
|
||||
early_stop: False
|
||||
stop_mode: min
|
||||
impatience: 2
|
||||
optimizer:
|
||||
name: Adam
|
||||
lr: !!float 1e-3
|
||||
conv:
|
||||
epoch: 1
|
||||
batch_size: 128
|
||||
impatience: 1
|
||||
optimizer:
|
||||
name: Adam
|
||||
lr: !!float 1e-3
|
||||
lr_scheduler:
|
||||
name: ReduceLROnPlateau
|
||||
patience: 3
|
||||
factor: 0.5
|
||||
gradient_clip: 0.1
|
50
HyCoRec/config/crs/hycorec/opendialkg.yaml
Normal file
50
HyCoRec/config/crs/hycorec/opendialkg.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
# dataset
|
||||
dataset: OpenDialKG
|
||||
tokenize: nltk
|
||||
# dataloader
|
||||
related_truncate: 1024
|
||||
context_truncate: 256
|
||||
response_truncate: 30
|
||||
scale: 1
|
||||
# model
|
||||
model: HyCoRec
|
||||
token_emb_dim: 300
|
||||
kg_emb_dim: 128
|
||||
num_bases: 8
|
||||
n_heads: 2
|
||||
n_layers: 2
|
||||
ffn_size: 300
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
relu_dropout: 0.1
|
||||
learn_positional_embeddings: false
|
||||
embeddings_scale: true
|
||||
reduction: false
|
||||
n_positions: 1024
|
||||
user_proj_dim: 512
|
||||
# HyCoRec-CHANGE
|
||||
mha_n_heads: 4
|
||||
pooling: Mean
|
||||
extension_strategy: Adaptive
|
||||
# optim
|
||||
rec:
|
||||
epoch: 1
|
||||
batch_size: 256
|
||||
early_stop: False
|
||||
stop_mode: min
|
||||
impatience: 2
|
||||
optimizer:
|
||||
name: Adam
|
||||
lr: !!float 1e-3
|
||||
conv:
|
||||
epoch: 1
|
||||
batch_size: 128
|
||||
impatience: 1
|
||||
optimizer:
|
||||
name: Adam
|
||||
lr: !!float 1e-3
|
||||
lr_scheduler:
|
||||
name: ReduceLROnPlateau
|
||||
patience: 3
|
||||
factor: 0.5
|
||||
gradient_clip: 0.1
|
50
HyCoRec/config/crs/hycorec/redial.yaml
Normal file
50
HyCoRec/config/crs/hycorec/redial.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
# dataset
|
||||
dataset: ReDial
|
||||
tokenize: nltk
|
||||
# dataloader
|
||||
related_truncate: 1024
|
||||
context_truncate: 256
|
||||
response_truncate: 30
|
||||
scale: 1
|
||||
# model
|
||||
model: HyCoRec
|
||||
token_emb_dim: 300
|
||||
kg_emb_dim: 128
|
||||
num_bases: 8
|
||||
n_heads: 2
|
||||
n_layers: 2
|
||||
ffn_size: 300
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
relu_dropout: 0.1
|
||||
learn_positional_embeddings: false
|
||||
embeddings_scale: true
|
||||
reduction: false
|
||||
n_positions: 1024
|
||||
user_proj_dim: 512
|
||||
# HyCoRec-CHANGE
|
||||
mha_n_heads: 4
|
||||
pooling: Mean
|
||||
extension_strategy: Adaptive
|
||||
# optim
|
||||
rec:
|
||||
epoch: 1
|
||||
batch_size: 256
|
||||
early_stop: False
|
||||
stop_mode: min
|
||||
impatience: 2
|
||||
optimizer:
|
||||
name: Adam
|
||||
lr: !!float 1e-3
|
||||
conv:
|
||||
epoch: 1
|
||||
batch_size: 128
|
||||
impatience: 1
|
||||
optimizer:
|
||||
name: Adam
|
||||
lr: !!float 1e-3
|
||||
lr_scheduler:
|
||||
name: ReduceLROnPlateau
|
||||
patience: 3
|
||||
factor: 0.5
|
||||
gradient_clip: 0.1
|
50
HyCoRec/config/crs/hycorec/tgredial.yaml
Normal file
50
HyCoRec/config/crs/hycorec/tgredial.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
# dataset
|
||||
dataset: TGReDial
|
||||
tokenize: pkuseg
|
||||
# dataloader
|
||||
related_truncate: 1024
|
||||
context_truncate: 256
|
||||
response_truncate: 30
|
||||
scale: 1
|
||||
# model
|
||||
model: HyCoRec
|
||||
token_emb_dim: 300
|
||||
kg_emb_dim: 128
|
||||
num_bases: 8
|
||||
n_heads: 2
|
||||
n_layers: 2
|
||||
ffn_size: 300
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
relu_dropout: 0.1
|
||||
learn_positional_embeddings: false
|
||||
embeddings_scale: true
|
||||
reduction: false
|
||||
n_positions: 1024
|
||||
user_proj_dim: 512
|
||||
# HyCoRec-CHANGE
|
||||
mha_n_heads: 4
|
||||
pooling: Attn
|
||||
extension_strategy: Adaptive
|
||||
# optim
|
||||
rec:
|
||||
epoch: 1
|
||||
batch_size: 64
|
||||
early_stop: False
|
||||
stop_mode: min
|
||||
impatience: 2
|
||||
optimizer:
|
||||
name: Adam
|
||||
lr: !!float 1e-3
|
||||
conv:
|
||||
epoch: 1
|
||||
batch_size: 128
|
||||
impatience: 1
|
||||
optimizer:
|
||||
name: Adam
|
||||
lr: !!float 1e-3
|
||||
lr_scheduler:
|
||||
name: ReduceLROnPlateau
|
||||
patience: 3
|
||||
factor: 0.5
|
||||
gradient_clip: 0.1
|
1
HyCoRec/crslab/__init__.py
Normal file
1
HyCoRec/crslab/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
32
HyCoRec/crslab/config/__init__.py
Normal file
32
HyCoRec/crslab/config/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/29
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
"""Config module which loads parameters for the whole system.
|
||||
|
||||
Attributes:
|
||||
SAVE_PATH (str): where system to save.
|
||||
DATASET_PATH (str): where dataset to save.
|
||||
MODEL_PATH (str): where model related data to save.
|
||||
PRETRAIN_PATH (str): where pretrained model to save.
|
||||
EMBEDDING_PATH (str): where pretrained embedding to save, used for evaluate embedding related metrics.
|
||||
"""
|
||||
|
||||
import os
|
||||
from os.path import dirname, realpath
|
||||
|
||||
from .config import Config
|
||||
|
||||
ROOT_PATH = dirname(dirname(dirname(realpath(__file__))))
|
||||
SAVE_PATH = os.path.join(ROOT_PATH, 'save')
|
||||
DATA_PATH = os.path.join(ROOT_PATH, 'data')
|
||||
DATASET_PATH = os.path.join(DATA_PATH, 'dataset')
|
||||
MODEL_PATH = os.path.join(DATA_PATH, 'model')
|
||||
PRETRAIN_PATH = os.path.join(MODEL_PATH, 'pretrain')
|
||||
EMBEDDING_PATH = os.path.join(DATA_PATH, 'embedding')
|
155
HyCoRec/crslab/config/config.py
Normal file
155
HyCoRec/crslab/config/config.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/23, 2021/1/9
|
||||
# @Author : Kun Zhou, Xiaolei Wang
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pprint import pprint
|
||||
|
||||
import yaml
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class Config:
|
||||
"""Configurator module that load the defined parameters."""
|
||||
|
||||
def __init__(self, config_file, gpu='-1', debug=False, seed=2020, pretrain=False, pretrain_epoch=None):
|
||||
"""Load parameters and set log level.
|
||||
|
||||
Args:
|
||||
config_file (str): path to the config file, which should be in ``yaml`` format.
|
||||
You can use default config provided in the `Github repo`_, or write it by yourself.
|
||||
debug (bool, optional): whether to enable debug function during running. Defaults to False.
|
||||
|
||||
.. _Github repo:
|
||||
https://github.com/RUCAIBox/CRSLab
|
||||
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
self.opt = self.load_yaml_configs(config_file)
|
||||
# pretrain epoch
|
||||
self.opt['pretrain'] = pretrain
|
||||
self.opt['pretrain_epoch'] = pretrain_epoch
|
||||
# gpu
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
|
||||
self.opt['gpu'] = [i for i in range(len(gpu.split(',')))] if gpu != '-1' else [-1]
|
||||
# dataset
|
||||
dataset = self.opt['dataset']
|
||||
tokenize = self.opt['tokenize']
|
||||
if isinstance(tokenize, dict):
|
||||
tokenize = ', '.join(tokenize.values())
|
||||
# model
|
||||
model = self.opt.get('model', None)
|
||||
rec_model = self.opt.get('rec_model', None)
|
||||
conv_model = self.opt.get('conv_model', None)
|
||||
policy_model = self.opt.get('policy_model', None)
|
||||
if model:
|
||||
model_name = model
|
||||
else:
|
||||
models = []
|
||||
if rec_model:
|
||||
models.append(rec_model)
|
||||
if conv_model:
|
||||
models.append(conv_model)
|
||||
if policy_model:
|
||||
models.append(policy_model)
|
||||
model_name = '_'.join(models)
|
||||
self.opt['model_name'] = model_name
|
||||
self.opt['rankfile'] = f"{model_name}-{dataset}.json"
|
||||
# log
|
||||
log_name = self.opt.get("log_name", dataset + '_' + model_name + '_' + time.strftime("%Y-%m-%d-%H-%M-%S",
|
||||
time.localtime())) + ".log"
|
||||
if not os.path.exists("log"):
|
||||
os.makedirs("log")
|
||||
logger.remove()
|
||||
if debug:
|
||||
level = 'DEBUG'
|
||||
else:
|
||||
level = 'INFO'
|
||||
logger.add(os.path.join("log", log_name), level=level)
|
||||
logger.add(lambda msg: tqdm.write(msg, end=''), colorize=True, level=level)
|
||||
|
||||
logger.info(f"[Dataset: {dataset} tokenized in {tokenize}]")
|
||||
if model:
|
||||
logger.info(f'[Model: {model}]')
|
||||
if rec_model:
|
||||
logger.info(f'[Recommendation Model: {rec_model}]')
|
||||
if conv_model:
|
||||
logger.info(f'[Conversation Model: {conv_model}]')
|
||||
if policy_model:
|
||||
logger.info(f'[Policy Model: {policy_model}]')
|
||||
logger.info("[Config]" + '/n' + json.dumps(self.opt, indent=4))
|
||||
|
||||
@staticmethod
|
||||
def load_yaml_configs(filename):
|
||||
"""This function reads ``yaml`` file to build config dictionary
|
||||
|
||||
Args:
|
||||
filename (str): path to ``yaml`` config
|
||||
|
||||
Returns:
|
||||
dict: config
|
||||
|
||||
"""
|
||||
config_dict = dict()
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
config_dict.update(yaml.safe_load(f.read()))
|
||||
return config_dict
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("index must be a str.")
|
||||
self.opt[key] = value
|
||||
|
||||
def __getitem__(self, item):
|
||||
if item in self.opt:
|
||||
return self.opt[item]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get(self, item, default=None):
|
||||
"""Get value of corrsponding item in config
|
||||
|
||||
Args:
|
||||
item (str): key to query in config
|
||||
default (optional): default value for item if not found in config. Defaults to None.
|
||||
|
||||
Returns:
|
||||
value of corrsponding item in config
|
||||
|
||||
"""
|
||||
if item in self.opt:
|
||||
return self.opt[item]
|
||||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("index must be a str.")
|
||||
return key in self.opt
|
||||
|
||||
def __str__(self):
|
||||
return str(self.opt)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt_dict = Config('config/crs/hycorec/hredial.yaml')
|
||||
pprint(opt_dict)
|
85
HyCoRec/crslab/data/__init__.py
Normal file
85
HyCoRec/crslab/data/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/24, 2020/12/29, 2020/12/17
|
||||
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com
|
||||
|
||||
# @Time : 2021/10/06
|
||||
# @Author : Zhipeng Zhao
|
||||
# @Email : oran_official@outlook.com
|
||||
|
||||
"""Data module which reads, processes and batches data for the whole system
|
||||
|
||||
Attributes:
|
||||
dataset_register_table (dict): record all supported dataset
|
||||
dataset_language_map (dict): record all dataset corresponding language
|
||||
dataloader_register_table (dict): record all model corresponding dataloader
|
||||
|
||||
"""
|
||||
|
||||
from crslab.data.dataloader import *
|
||||
from crslab.data.dataset import *
|
||||
|
||||
dataset_register_table = {
|
||||
'HReDial': HReDialDataset,
|
||||
'HTGReDial': HTGReDialDataset,
|
||||
'OpenDialKG': OpenDialKGDataset,
|
||||
'DuRecDial': DuRecDialDataset,
|
||||
'ReDial': ReDialDataset,
|
||||
'TGReDial': TGReDialDataset,
|
||||
}
|
||||
|
||||
dataset_language_map = {
|
||||
'ReDial': 'en',
|
||||
'TGReDial': 'zh',
|
||||
'HReDial': 'en',
|
||||
'HTGReDial': 'zh',
|
||||
'OpenDialKG': 'en',
|
||||
'DuRecDial': 'zh',
|
||||
}
|
||||
|
||||
dataloader_register_table = {
|
||||
'HyCoRec': HyCoRecDataLoader,
|
||||
}
|
||||
|
||||
|
||||
def get_dataset(opt, tokenize, restore, save) -> BaseDataset:
|
||||
"""get and process dataset
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataset or the whole system.
|
||||
tokenize (str): how to tokenize the dataset.
|
||||
restore (bool): whether to restore saved dataset which has been processed.
|
||||
save (bool): whether to save dataset after processing.
|
||||
|
||||
Returns:
|
||||
processed dataset
|
||||
|
||||
"""
|
||||
dataset = opt['dataset']
|
||||
if dataset in dataset_register_table:
|
||||
return dataset_register_table[dataset](opt, tokenize, restore, save)
|
||||
else:
|
||||
raise NotImplementedError(f'The dataloader [{dataset}] has not been implemented')
|
||||
|
||||
|
||||
def get_dataloader(opt, dataset, vocab) -> BaseDataLoader:
|
||||
"""get dataloader to batchify dataset
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataloader or the whole system.
|
||||
dataset: processed raw data, no side data.
|
||||
vocab (dict): all kinds of useful size, idx and map between token and idx.
|
||||
|
||||
Returns:
|
||||
dataloader
|
||||
|
||||
"""
|
||||
model_name = opt['model_name']
|
||||
if model_name in dataloader_register_table:
|
||||
return dataloader_register_table[model_name](opt, dataset, vocab)
|
||||
else:
|
||||
raise NotImplementedError(f'The dataloader [{model_name}] has not been implemented')
|
2
HyCoRec/crslab/data/dataloader/__init__.py
Normal file
2
HyCoRec/crslab/data/dataloader/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseDataLoader
|
||||
from .hycorec import HyCoRecDataLoader
|
211
HyCoRec/crslab/data/dataloader/base.py
Normal file
211
HyCoRec/crslab/data/dataloader/base.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/23, 2020/12/29
|
||||
# @Author : Kun Zhou, Xiaolei Wang
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
|
||||
|
||||
import random
|
||||
from abc import ABC
|
||||
|
||||
from loguru import logger
|
||||
from math import ceil
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class BaseDataLoader(ABC):
|
||||
"""Abstract class of dataloader
|
||||
|
||||
Notes:
|
||||
``'scale'`` can be set in config to limit the size of dataset.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, dataset):
|
||||
"""
|
||||
Args:
|
||||
opt (Config or dict): config for dataloader or the whole system.
|
||||
dataset: dataset
|
||||
|
||||
"""
|
||||
self.opt = opt
|
||||
self.dataset = dataset
|
||||
self.scale = opt.get('scale', 1)
|
||||
assert 0 < self.scale <= 1
|
||||
|
||||
def get_data(self, batch_fn, batch_size, shuffle=True, process_fn=None):
|
||||
"""Collate batch data for system to fit
|
||||
|
||||
Args:
|
||||
batch_fn (func): function to collate data
|
||||
batch_size (int):
|
||||
shuffle (bool, optional): Defaults to True.
|
||||
process_fn (func, optional): function to process dataset before batchify. Defaults to None.
|
||||
|
||||
Yields:
|
||||
tuple or dict of torch.Tensor: batch data for system to fit
|
||||
|
||||
"""
|
||||
dataset = self.dataset
|
||||
if process_fn is not None:
|
||||
dataset = process_fn()
|
||||
# logger.info('[Finish dataset process before batchify]')
|
||||
dataset = dataset[:ceil(len(dataset) * self.scale)]
|
||||
logger.debug(f'[Dataset size: {len(dataset)}]')
|
||||
|
||||
batch_num = ceil(len(dataset) / batch_size)
|
||||
idx_list = list(range(len(dataset)))
|
||||
if shuffle:
|
||||
random.shuffle(idx_list)
|
||||
|
||||
for start_idx in tqdm(range(batch_num)):
|
||||
batch_idx = idx_list[start_idx * batch_size: (start_idx + 1) * batch_size]
|
||||
batch = [dataset[idx] for idx in batch_idx]
|
||||
batch = batch_fn(batch)
|
||||
if batch == False:
|
||||
continue
|
||||
else:
|
||||
# print(batch)
|
||||
yield(batch)
|
||||
|
||||
def get_conv_data(self, batch_size, shuffle=True):
|
||||
"""get_data wrapper for conversation.
|
||||
|
||||
You can implement your own process_fn in ``conv_process_fn``, batch_fn in ``conv_batchify``.
|
||||
|
||||
Args:
|
||||
batch_size (int):
|
||||
shuffle (bool, optional): Defaults to True.
|
||||
|
||||
Yields:
|
||||
tuple or dict of torch.Tensor: batch data for conversation.
|
||||
|
||||
"""
|
||||
return self.get_data(self.conv_batchify, batch_size, shuffle, self.conv_process_fn)
|
||||
|
||||
def get_rec_data(self, batch_size, shuffle=True):
|
||||
"""get_data wrapper for recommendation.
|
||||
|
||||
You can implement your own process_fn in ``rec_process_fn``, batch_fn in ``rec_batchify``.
|
||||
|
||||
Args:
|
||||
batch_size (int):
|
||||
shuffle (bool, optional): Defaults to True.
|
||||
|
||||
Yields:
|
||||
tuple or dict of torch.Tensor: batch data for recommendation.
|
||||
|
||||
"""
|
||||
return self.get_data(self.rec_batchify, batch_size, shuffle, self.rec_process_fn)
|
||||
|
||||
def get_policy_data(self, batch_size, shuffle=True):
|
||||
"""get_data wrapper for policy.
|
||||
|
||||
You can implement your own process_fn in ``self.policy_process_fn``, batch_fn in ``policy_batchify``.
|
||||
|
||||
Args:
|
||||
batch_size (int):
|
||||
shuffle (bool, optional): Defaults to True.
|
||||
|
||||
Yields:
|
||||
tuple or dict of torch.Tensor: batch data for policy.
|
||||
|
||||
"""
|
||||
return self.get_data(self.policy_batchify, batch_size, shuffle, self.policy_process_fn)
|
||||
|
||||
def conv_process_fn(self):
|
||||
"""Process whole data for conversation before batch_fn.
|
||||
|
||||
Returns:
|
||||
processed dataset. Defaults to return the same as `self.dataset`.
|
||||
|
||||
"""
|
||||
return self.dataset
|
||||
|
||||
def conv_batchify(self, batch):
|
||||
"""batchify data for conversation after process.
|
||||
|
||||
Args:
|
||||
batch (list): processed batch dataset.
|
||||
|
||||
Returns:
|
||||
batch data for the system to train conversation part.
|
||||
"""
|
||||
raise NotImplementedError('dataloader must implement conv_batchify() method')
|
||||
|
||||
def rec_process_fn(self):
|
||||
"""Process whole data for recommendation before batch_fn.
|
||||
|
||||
Returns:
|
||||
processed dataset. Defaults to return the same as `self.dataset`.
|
||||
|
||||
"""
|
||||
return self.dataset
|
||||
|
||||
def rec_batchify(self, batch):
|
||||
"""batchify data for recommendation after process.
|
||||
|
||||
Args:
|
||||
batch (list): processed batch dataset.
|
||||
|
||||
Returns:
|
||||
batch data for the system to train recommendation part.
|
||||
"""
|
||||
raise NotImplementedError('dataloader must implement rec_batchify() method')
|
||||
|
||||
def policy_process_fn(self):
|
||||
"""Process whole data for policy before batch_fn.
|
||||
|
||||
Returns:
|
||||
processed dataset. Defaults to return the same as `self.dataset`.
|
||||
|
||||
"""
|
||||
return self.dataset
|
||||
|
||||
def policy_batchify(self, batch):
|
||||
"""batchify data for policy after process.
|
||||
|
||||
Args:
|
||||
batch (list): processed batch dataset.
|
||||
|
||||
Returns:
|
||||
batch data for the system to train policy part.
|
||||
"""
|
||||
raise NotImplementedError('dataloader must implement policy_batchify() method')
|
||||
|
||||
def retain_recommender_target(self):
|
||||
"""keep data whose role is recommender.
|
||||
|
||||
Returns:
|
||||
Recommender part of ``self.dataset``.
|
||||
|
||||
"""
|
||||
dataset = []
|
||||
for conv_dict in tqdm(self.dataset):
|
||||
if conv_dict['role'] == 'Recommender':
|
||||
dataset.append(conv_dict)
|
||||
return dataset
|
||||
|
||||
def rec_interact(self, data):
|
||||
"""process user input data for system to recommend.
|
||||
|
||||
Args:
|
||||
data: user input data.
|
||||
|
||||
Returns:
|
||||
data for system to recommend.
|
||||
"""
|
||||
pass
|
||||
|
||||
def conv_interact(self, data):
|
||||
"""Process user input data for system to converse.
|
||||
|
||||
Args:
|
||||
data: user input data.
|
||||
|
||||
Returns:
|
||||
data for system in converse.
|
||||
"""
|
||||
pass
|
144
HyCoRec/crslab/data/dataloader/hycorec.py
Normal file
144
HyCoRec/crslab/data/dataloader/hycorec.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2021/5/26
|
||||
# @Author : Chenzhan Shang
|
||||
# @email : czshang@outlook.com
|
||||
|
||||
import pickle
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from crslab.data.dataloader.base import BaseDataLoader
|
||||
from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt
|
||||
|
||||
|
||||
class HyCoRecDataLoader(BaseDataLoader):
|
||||
"""Dataloader for model KBRD.
|
||||
|
||||
Notes:
|
||||
You can set the following parameters in config:
|
||||
|
||||
- ``"context_truncate"``: the maximum length of context.
|
||||
- ``"response_truncate"``: the maximum length of response.
|
||||
- ``"entity_truncate"``: the maximum length of mentioned entities in context.
|
||||
|
||||
The following values must be specified in ``vocab``:
|
||||
|
||||
- ``"pad"``
|
||||
- ``"start"``
|
||||
- ``"end"``
|
||||
- ``"pad_entity"``
|
||||
|
||||
the above values specify the id of needed special token.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, dataset, vocab):
|
||||
"""
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataloader or the whole system.
|
||||
dataset: data for model.
|
||||
vocab (dict): all kinds of useful size, idx and map between token and idx.
|
||||
|
||||
"""
|
||||
super().__init__(opt, dataset)
|
||||
self.pad_token_idx = vocab["tok2ind"]["__pad__"]
|
||||
self.start_token_idx = vocab["tok2ind"]["__start__"]
|
||||
self.end_token_idx = vocab["tok2ind"]["__end__"]
|
||||
self.split_token_idx = vocab["tok2ind"].get("_split_", None)
|
||||
self.related_truncate = opt.get("related_truncate", None)
|
||||
self.context_truncate = opt.get("context_truncate", None)
|
||||
self.response_truncate = opt.get("response_truncate", None)
|
||||
self.entity_truncate = opt.get("entity_truncate", None)
|
||||
self.review_entity2id = vocab["entity2id"]
|
||||
return
|
||||
|
||||
def rec_process_fn(self):
|
||||
augment_dataset = []
|
||||
for conv_dict in tqdm(self.dataset):
|
||||
if conv_dict["role"] == "Recommender":
|
||||
for item in conv_dict["items"]:
|
||||
augment_conv_dict = {
|
||||
"conv_id": conv_dict["conv_id"],
|
||||
"related_item": conv_dict["item"],
|
||||
"related_entity": conv_dict["entity"],
|
||||
"related_word": conv_dict["word"],
|
||||
"item": item,
|
||||
}
|
||||
augment_dataset.append(augment_conv_dict)
|
||||
|
||||
return augment_dataset
|
||||
|
||||
def rec_batchify(self, batch):
|
||||
batch_related_item = []
|
||||
batch_related_entity = []
|
||||
batch_related_word = []
|
||||
batch_movies = []
|
||||
batch_conv_id = []
|
||||
for conv_dict in batch:
|
||||
batch_related_item.append(conv_dict["related_item"])
|
||||
batch_related_entity.append(conv_dict["related_entity"])
|
||||
batch_related_word.append(conv_dict["related_word"])
|
||||
batch_movies.append(conv_dict["item"])
|
||||
batch_conv_id.append(conv_dict["conv_id"])
|
||||
|
||||
res = {
|
||||
"conv_id": batch_conv_id,
|
||||
"related_item": batch_related_item,
|
||||
"related_entity": batch_related_entity,
|
||||
"related_word": batch_related_word,
|
||||
"item": torch.tensor(batch_movies, dtype=torch.long),
|
||||
}
|
||||
|
||||
return res
|
||||
|
||||
def conv_process_fn(self, *args, **kwargs):
|
||||
return self.retain_recommender_target()
|
||||
|
||||
def conv_batchify(self, batch):
|
||||
batch_related_tokens = []
|
||||
batch_context_tokens = []
|
||||
|
||||
batch_related_item = []
|
||||
batch_related_entity = []
|
||||
batch_related_word = []
|
||||
|
||||
batch_response = []
|
||||
batch_conv_id = []
|
||||
for conv_dict in batch:
|
||||
batch_related_tokens.append(
|
||||
truncate(conv_dict["tokens"][-1], self.related_truncate, truncate_tail=False)
|
||||
)
|
||||
batch_context_tokens.append(
|
||||
truncate(merge_utt(
|
||||
conv_dict["tokens"],
|
||||
start_token_idx=self.start_token_idx,
|
||||
split_token_idx=self.split_token_idx,
|
||||
final_token_idx=self.end_token_idx
|
||||
), self.context_truncate, truncate_tail=False)
|
||||
)
|
||||
|
||||
batch_related_item.append(conv_dict["item"])
|
||||
batch_related_entity.append(conv_dict["entity"])
|
||||
batch_related_word.append(conv_dict["word"])
|
||||
|
||||
batch_response.append(
|
||||
add_start_end_token_idx(truncate(conv_dict["response"], self.response_truncate - 2),
|
||||
start_token_idx=self.start_token_idx,
|
||||
end_token_idx=self.end_token_idx))
|
||||
batch_conv_id.append(conv_dict["conv_id"])
|
||||
|
||||
res = {
|
||||
"related_tokens": padded_tensor(batch_related_tokens, self.pad_token_idx, pad_tail=False),
|
||||
"context_tokens": padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),
|
||||
"related_item": batch_related_item,
|
||||
"related_entity": batch_related_entity,
|
||||
"related_word": batch_related_word,
|
||||
"response": padded_tensor(batch_response, self.pad_token_idx),
|
||||
"conv_id": batch_conv_id,
|
||||
}
|
||||
|
||||
return res
|
||||
|
||||
def policy_batchify(self, *args, **kwargs):
|
||||
pass
|
182
HyCoRec/crslab/data/dataloader/utils.py
Normal file
182
HyCoRec/crslab/data/dataloader/utils.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/10
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/20, 2020/12/15
|
||||
# @Author : Xiaolei Wang, Yuanhang Zhou
|
||||
# @email : wxl1999@foxmail.com, sdzyh002@gmail
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2021/10/06
|
||||
# @Author : Zhipeng Zhao
|
||||
# @Email : oran_official@outlook.com
|
||||
|
||||
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
from typing import List, Union, Optional
|
||||
|
||||
|
||||
def padded_tensor(
|
||||
items: List[Union[List[int], torch.LongTensor]],
|
||||
pad_idx: int = 0,
|
||||
pad_tail: bool = True,
|
||||
max_len: Optional[int] = None,
|
||||
) -> torch.LongTensor:
|
||||
"""Create a padded matrix from an uneven list of lists.
|
||||
|
||||
Returns padded matrix.
|
||||
|
||||
Matrix is right-padded (filled to the right) by default, but can be
|
||||
left padded if the flag is set to True.
|
||||
|
||||
Matrix can also be placed on cuda automatically.
|
||||
|
||||
:param list[iter[int]] items: List of items
|
||||
:param int pad_idx: the value to use for padding
|
||||
:param bool pad_tail:
|
||||
:param int max_len: if None, the max length is the maximum item length
|
||||
|
||||
:returns: padded tensor.
|
||||
:rtype: Tensor[int64]
|
||||
|
||||
"""
|
||||
# number of items
|
||||
n = len(items)
|
||||
# length of each item
|
||||
lens: List[int] = [len(item) for item in items] # type: ignore
|
||||
# max in time dimension
|
||||
t = max(lens) if max_len is None else max_len
|
||||
# if input tensors are empty, we should expand to nulls
|
||||
t = max(t, 1)
|
||||
|
||||
if isinstance(items[0], torch.Tensor):
|
||||
# keep type of input tensors, they may already be cuda ones
|
||||
output = items[0].new(n, t) # type: ignore
|
||||
else:
|
||||
output = torch.LongTensor(n, t) # type: ignore
|
||||
output.fill_(pad_idx)
|
||||
|
||||
for i, (item, length) in enumerate(zip(items, lens)):
|
||||
if length == 0:
|
||||
# skip empty items
|
||||
continue
|
||||
if not isinstance(item, torch.Tensor):
|
||||
# put non-tensors into a tensor
|
||||
item = torch.tensor(item, dtype=torch.long) # type: ignore
|
||||
if pad_tail:
|
||||
# place at beginning
|
||||
output[i, :length] = item
|
||||
else:
|
||||
# place at end
|
||||
output[i, t - length:] = item
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_onehot(data_list, categories) -> torch.Tensor:
|
||||
"""Transform lists of label into one-hot.
|
||||
|
||||
Args:
|
||||
data_list (list of list of int): source data.
|
||||
categories (int): #label class.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: one-hot labels.
|
||||
|
||||
"""
|
||||
onehot_labels = []
|
||||
for label_list in data_list:
|
||||
onehot_label = torch.zeros(categories)
|
||||
for label in label_list:
|
||||
onehot_label[label] = 1.0 / len(label_list)
|
||||
onehot_labels.append(onehot_label)
|
||||
return torch.stack(onehot_labels, dim=0)
|
||||
|
||||
|
||||
def add_start_end_token_idx(vec: list, start_token_idx: int = None, end_token_idx: int = None):
|
||||
"""Can choose to add start token in the beginning and end token in the end.
|
||||
|
||||
Args:
|
||||
vec: source list composed of indexes.
|
||||
start_token_idx: index of start token.
|
||||
end_token_idx: index of end token.
|
||||
|
||||
Returns:
|
||||
list: list added start or end token index.
|
||||
|
||||
"""
|
||||
res = copy(vec)
|
||||
if start_token_idx:
|
||||
res.insert(0, start_token_idx)
|
||||
if end_token_idx:
|
||||
res.append(end_token_idx)
|
||||
return res
|
||||
|
||||
|
||||
def truncate(vec, max_length, truncate_tail=True):
|
||||
"""truncate vec to make its length no more than max length.
|
||||
|
||||
Args:
|
||||
vec (list): source list.
|
||||
max_length (int)
|
||||
truncate_tail (bool, optional): Defaults to True.
|
||||
|
||||
Returns:
|
||||
list: truncated vec.
|
||||
|
||||
"""
|
||||
if max_length is None:
|
||||
return vec
|
||||
if len(vec) <= max_length:
|
||||
return vec
|
||||
if max_length == 0:
|
||||
return []
|
||||
if truncate_tail:
|
||||
return vec[:max_length]
|
||||
else:
|
||||
return vec[-max_length:]
|
||||
|
||||
|
||||
def merge_utt(conversation, start_token_idx=None, split_token_idx=None, keep_split_in_tail=False, final_token_idx=None):
|
||||
"""merge utterances in one conversation.
|
||||
|
||||
Args:
|
||||
conversation (list of list of int): conversation consist of utterances consist of tokens.
|
||||
split_token_idx (int): index of split token. Defaults to None.
|
||||
keep_split_in_tail (bool): split in tail or head. Defaults to False.
|
||||
final_token_idx (int): index of final token. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list: tokens of all utterances in one list.
|
||||
|
||||
"""
|
||||
merged_conv = []
|
||||
if start_token_idx:
|
||||
merged_conv.append(start_token_idx)
|
||||
for utt in conversation:
|
||||
for token in utt:
|
||||
merged_conv.append(token)
|
||||
if split_token_idx:
|
||||
merged_conv.append(split_token_idx)
|
||||
if split_token_idx and not keep_split_in_tail:
|
||||
merged_conv = merged_conv[:-1]
|
||||
if final_token_idx:
|
||||
merged_conv.append(final_token_idx)
|
||||
return merged_conv
|
||||
|
||||
def merge_utt_replace(conversation,detect_token=None,replace_token=None,method="in"):
|
||||
if method == 'in':
|
||||
replaced_conv = []
|
||||
for utt in conversation:
|
||||
for token in utt:
|
||||
if detect_token in token:
|
||||
replaced_conv.append(replace_token)
|
||||
else:
|
||||
replaced_conv.append(token)
|
||||
return replaced_conv
|
||||
else:
|
||||
return [token.replace(detect_token,replace_token) for utt in conversation for token in utt]
|
8
HyCoRec/crslab/data/dataset/__init__.py
Normal file
8
HyCoRec/crslab/data/dataset/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from .base import BaseDataset
|
||||
from .hredial import HReDialDataset
|
||||
from .htgredial import HTGReDialDataset
|
||||
from .opendialkg import OpenDialKGDataset
|
||||
from .durecdial import DuRecDialDataset
|
||||
|
||||
from .redial import ReDialDataset
|
||||
from .tgredial import TGReDialDataset
|
171
HyCoRec/crslab/data/dataset/base.py
Normal file
171
HyCoRec/crslab/data/dataset/base.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/23, 2020/12/13
|
||||
# @Author : Kun Zhou, Xiaolei Wang
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
|
||||
|
||||
import os
|
||||
import pickle as pkl
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from crslab.download import build
|
||||
|
||||
|
||||
class BaseDataset(ABC):
|
||||
"""Abstract class of dataset
|
||||
|
||||
Notes:
|
||||
``'embedding'`` can be specified in config to use pretrained word embedding.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, dpath, resource, restore=False, save=False):
|
||||
"""Download resource, load, process data. Support restore and save processed dataset.
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataset or the whole system.
|
||||
dpath (str): where to store dataset.
|
||||
resource (dict): version, download file and special token idx of tokenized dataset.
|
||||
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
|
||||
save (bool): whether to save dataset after processing. Defaults to False.
|
||||
|
||||
"""
|
||||
self.opt = opt
|
||||
self.dpath = dpath
|
||||
|
||||
# download
|
||||
dfile = resource['file']
|
||||
build(dpath, dfile, version=resource['version'])
|
||||
|
||||
if not restore:
|
||||
# load and process
|
||||
train_data, valid_data, test_data, self.vocab = self._load_data()
|
||||
logger.info('[Finish data load]')
|
||||
self.train_data, self.valid_data, self.test_data, self.side_data = self._data_preprocess(train_data, valid_data, test_data)
|
||||
embedding = opt.get('embedding', None)
|
||||
if embedding:
|
||||
self.side_data["embedding"] = np.load(os.path.join(self.dpath, embedding))
|
||||
logger.debug(f'[Load pretrained embedding {embedding}]')
|
||||
logger.info('[Finish data preprocess]')
|
||||
else:
|
||||
self._load_other_data()
|
||||
self.train_data, self.valid_data, self.test_data, self.side_data, self.vocab = self._load_from_restore()
|
||||
|
||||
if save:
|
||||
data = (self.train_data, self.valid_data, self.test_data, self.side_data, self.vocab)
|
||||
self._save_to_one(data)
|
||||
|
||||
@abstractmethod
|
||||
def _load_other_data(self):
|
||||
"""
|
||||
Load review.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _load_data(self):
|
||||
"""Load dataset.
|
||||
|
||||
Returns:
|
||||
(any, any, any, dict):
|
||||
|
||||
raw train, valid and test data.
|
||||
|
||||
vocab: all kinds of useful size, idx and map between token and idx.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _data_preprocess(self, train_data, valid_data, test_data):
|
||||
"""Process raw train, valid, test data.
|
||||
|
||||
Args:
|
||||
train_data: train dataset.
|
||||
valid_data: valid dataset.
|
||||
test_data: test dataset.
|
||||
|
||||
Returns:
|
||||
(list of dict, dict):
|
||||
|
||||
train/valid/test_data, each dict is in the following format::
|
||||
|
||||
{
|
||||
'role' (str):
|
||||
'Seeker' or 'Recommender',
|
||||
'user_profile' (list of list of int):
|
||||
id of tokens of sentences of user profile,
|
||||
'context_tokens' (list of list int):
|
||||
token ids of preprocessed contextual dialogs,
|
||||
'response' (list of int):
|
||||
token ids of the ground-truth response,
|
||||
'interaction_history' (list of int):
|
||||
id of items which have interaction of the user in current turn,
|
||||
'context_items' (list of int):
|
||||
item ids mentioned in context,
|
||||
'items' (list of int):
|
||||
item ids mentioned in current turn, we only keep
|
||||
those in entity kg for comparison,
|
||||
'context_entities' (list of int):
|
||||
if necessary, id of entities in context,
|
||||
'context_words' (list of int):
|
||||
if necessary, id of words in context,
|
||||
'context_policy' (list of list of list):
|
||||
policy of each context turn, one turn may have several policies,
|
||||
where first is action and second is keyword,
|
||||
'target' (list): policy of current turn,
|
||||
'final' (list): final goal for current turn
|
||||
}
|
||||
|
||||
side_data, which is in the following format::
|
||||
|
||||
{
|
||||
'entity_kg': {
|
||||
'edge' (list of tuple): (head_entity_id, tail_entity_id, relation_id),
|
||||
'n_relation' (int): number of distinct relations,
|
||||
'entity' (list of str): str of entities, used for entity linking
|
||||
}
|
||||
'word_kg': {
|
||||
'edge' (list of tuple): (head_entity_id, tail_entity_id),
|
||||
'entity' (list of str): str of entities, used for entity linking
|
||||
}
|
||||
'item_entity_ids' (list of int): entity id of each item;
|
||||
}
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
def _load_from_restore(self, file_name="all_data.pkl"):
|
||||
"""Restore saved dataset.
|
||||
|
||||
Args:
|
||||
file_name (str): file of saved dataset. Defaults to "all_data.pkl".
|
||||
|
||||
"""
|
||||
if not os.path.exists(os.path.join(self.dpath, file_name)):
|
||||
raise ValueError(f'Saved dataset [{file_name}] does not exist')
|
||||
with open(os.path.join(self.dpath, file_name), 'rb') as f:
|
||||
dataset = pkl.load(f)
|
||||
logger.info(f'Restore dataset from [{file_name}]')
|
||||
return dataset
|
||||
|
||||
def _save_to_one(self, data, file_name="all_data.pkl"):
|
||||
"""Save all processed dataset and vocab into one file.
|
||||
|
||||
Args:
|
||||
data (tuple): all dataset and vocab.
|
||||
file_name (str, optional): file to save dataset. Defaults to "all_data.pkl".
|
||||
|
||||
"""
|
||||
if not os.path.exists(self.dpath):
|
||||
os.makedirs(self.dpath)
|
||||
save_path = os.path.join(self.dpath, file_name)
|
||||
with open(save_path, 'wb') as f:
|
||||
pkl.dump(data, f)
|
||||
logger.info(f'[Save dataset to {file_name}]')
|
1
HyCoRec/crslab/data/dataset/durecdial/__init__.py
Normal file
1
HyCoRec/crslab/data/dataset/durecdial/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .durecdial import DuRecDialDataset
|
281
HyCoRec/crslab/data/dataset/durecdial/durecdial.py
Normal file
281
HyCoRec/crslab/data/dataset/durecdial/durecdial.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# @Time : 2020/12/21
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/12/21, 2021/1/2
|
||||
# @Author : Kun Zhou, Xiaolei Wang
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
|
||||
|
||||
r"""
|
||||
DuRecDial
|
||||
=========
|
||||
References:
|
||||
Liu, Zeming, et al. `"Towards Conversational Recommendation over Multi-Type Dialogs."`_ in ACL 2020.
|
||||
|
||||
.. _"Towards Conversational Recommendation over Multi-Type Dialogs.":
|
||||
https://www.aclweb.org/anthology/2020.acl-main.98/
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from copy import copy
|
||||
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from crslab.config import DATASET_PATH
|
||||
from crslab.data.dataset.base import BaseDataset
|
||||
from .resources import resources
|
||||
|
||||
|
||||
class DuRecDialDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
Attributes:
|
||||
train_data: train dataset.
|
||||
valid_data: valid dataset.
|
||||
test_data: test dataset.
|
||||
vocab (dict): ::
|
||||
|
||||
{
|
||||
'tok2ind': map from token to index,
|
||||
'ind2tok': map from index to token,
|
||||
'entity2id': map from entity to index,
|
||||
'id2entity': map from index to entity,
|
||||
'word2id': map from word to index,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': max(self.entity2id.values()) + 1,
|
||||
'n_word': max(self.word2id.values()) + 1,
|
||||
}
|
||||
|
||||
Notes:
|
||||
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, tokenize, restore=False, save=False):
|
||||
"""
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataset or the whole system.
|
||||
tokenize (str): how to tokenize dataset.
|
||||
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
|
||||
save (bool): whether to save dataset after processing. Defaults to False.
|
||||
|
||||
"""
|
||||
resource = resources[tokenize]
|
||||
self.special_token_idx = resource['special_token_idx']
|
||||
self.unk_token_idx = self.special_token_idx['unk']
|
||||
dpath = os.path.join(DATASET_PATH, 'durecdial', tokenize)
|
||||
self.train_review = None
|
||||
self.valid_review = None
|
||||
self.test_review = None
|
||||
super().__init__(opt, dpath, resource, restore, save)
|
||||
|
||||
def _load_data(self):
|
||||
train_data, valid_data, test_data = self._load_raw_data()
|
||||
self._load_vocab()
|
||||
self._load_other_data()
|
||||
|
||||
vocab = {
|
||||
'tok2ind': self.tok2ind,
|
||||
'ind2tok': self.ind2tok,
|
||||
'entity2id': self.entity2id,
|
||||
'id2entity': self.id2entity,
|
||||
'word2id': self.word2id,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': self.n_entity,
|
||||
'n_word': self.n_word,
|
||||
}
|
||||
vocab.update(self.special_token_idx)
|
||||
|
||||
return train_data, valid_data, test_data, vocab
|
||||
|
||||
def _load_raw_data(self):
|
||||
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
|
||||
train_data = json.load(f)
|
||||
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
|
||||
valid_data = json.load(f)
|
||||
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
|
||||
test_data = json.load(f)
|
||||
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
|
||||
|
||||
return train_data, valid_data, test_data
|
||||
|
||||
def _load_vocab(self):
|
||||
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
|
||||
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
|
||||
|
||||
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
|
||||
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
|
||||
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
|
||||
|
||||
def _load_other_data(self):
|
||||
# entity kg
|
||||
with open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8') as f:
|
||||
self.entity2id = json.load(f) # {entity: entity_id}
|
||||
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
|
||||
self.n_entity = max(self.entity2id.values()) + 1
|
||||
# {head_entity_id: [(relation_id, tail_entity_id)]}
|
||||
self.entity_kg = open(os.path.join(self.dpath, 'entity_subkg.txt'), encoding='utf-8')
|
||||
logger.debug(
|
||||
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'entity_subkg.txt')}]")
|
||||
|
||||
# hownet
|
||||
# {concept: concept_id}
|
||||
with open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8') as f:
|
||||
self.word2id = json.load(f)
|
||||
self.n_word = max(self.word2id.values()) + 1
|
||||
# {concept \t relation\t concept}
|
||||
self.word_kg = open(os.path.join(self.dpath, 'hownet_subkg.txt'), encoding='utf-8')
|
||||
logger.debug(
|
||||
f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet_subkg.txt')}]")
|
||||
|
||||
def _data_preprocess(self, train_data, valid_data, test_data):
|
||||
processed_train_data = self._raw_data_process(train_data)
|
||||
logger.debug("[Finish train data process]")
|
||||
processed_valid_data = self._raw_data_process(valid_data)
|
||||
logger.debug("[Finish valid data process]")
|
||||
processed_test_data = self._raw_data_process(test_data)
|
||||
logger.debug("[Finish test data process]")
|
||||
processed_side_data = self._side_data_process()
|
||||
logger.debug("[Finish side data process]")
|
||||
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
|
||||
|
||||
def _raw_data_process(self, raw_data):
|
||||
augmented_convs = [self._convert_to_id(idx, conversation) for idx, conversation in enumerate(tqdm(raw_data))]
|
||||
augmented_conv_dicts = []
|
||||
for conv in tqdm(augmented_convs):
|
||||
augmented_conv_dicts.extend(self._augment_and_add(conv))
|
||||
return augmented_conv_dicts
|
||||
|
||||
def _convert_to_id(self, idx, conversation):
|
||||
augmented_convs = []
|
||||
last_role = None
|
||||
conv_id = conversation.get("conv_id", idx)
|
||||
related_item = []
|
||||
related_entity = []
|
||||
related_word = []
|
||||
for utt in conversation['dialog']:
|
||||
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
|
||||
item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id]
|
||||
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
|
||||
word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]
|
||||
|
||||
related_item += item_ids
|
||||
related_entity += entity_ids
|
||||
related_word += word_ids
|
||||
|
||||
if utt["role"] == last_role:
|
||||
augmented_convs[-1]["text"] += text_token_ids
|
||||
augmented_convs[-1]["item"] += item_ids
|
||||
augmented_convs[-1]["entity"] += entity_ids
|
||||
augmented_convs[-1]["word"] += word_ids
|
||||
else:
|
||||
augmented_convs.append({
|
||||
"conv_id": conv_id,
|
||||
"role": utt["role"],
|
||||
"text": text_token_ids,
|
||||
"item": related_item,
|
||||
"entity": related_entity,
|
||||
"word": related_word
|
||||
})
|
||||
last_role = utt["role"]
|
||||
|
||||
return augmented_convs
|
||||
|
||||
def _augment_and_add(self, raw_conv_dict):
|
||||
augmented_conv_dicts = []
|
||||
context_tokens, context_entities, context_words, context_items = [], [], [], []
|
||||
entity_set, word_set = set(), set()
|
||||
for i, conv in enumerate(raw_conv_dict):
|
||||
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"]
|
||||
if len(context_tokens) > 0:
|
||||
conv_dict = {
|
||||
"conv_id": conv["conv_id"],
|
||||
"role": conv["role"],
|
||||
"tokens": copy(context_tokens),
|
||||
"response": text_tokens,
|
||||
"item": copy(context_items),
|
||||
"entity": copy(context_entities),
|
||||
"word": copy(context_words),
|
||||
"items": movies
|
||||
}
|
||||
augmented_conv_dicts.append(conv_dict)
|
||||
|
||||
context_tokens.append(text_tokens)
|
||||
context_items += movies
|
||||
for entity in entities + movies:
|
||||
if entity not in entity_set:
|
||||
entity_set.add(entity)
|
||||
context_entities.append(entity)
|
||||
for word in words:
|
||||
if word not in word_set:
|
||||
word_set.add(word)
|
||||
context_words.append(word)
|
||||
|
||||
return augmented_conv_dicts
|
||||
|
||||
def _side_data_process(self):
|
||||
processed_entity_kg = self._entity_kg_process()
|
||||
logger.debug("[Finish entity KG process]")
|
||||
processed_word_kg = self._word_kg_process()
|
||||
logger.debug("[Finish word KG process]")
|
||||
with open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8') as f:
|
||||
item_entity_ids = json.load(f)
|
||||
logger.debug('[Load movie entity ids]')
|
||||
|
||||
side_data = {
|
||||
"entity_kg": processed_entity_kg,
|
||||
"word_kg": processed_word_kg,
|
||||
"item_entity_ids": item_entity_ids,
|
||||
}
|
||||
return side_data
|
||||
|
||||
def _entity_kg_process(self):
|
||||
edge_list = [] # [(entity, entity, relation)]
|
||||
for line in self.entity_kg:
|
||||
triple = line.strip().split('\t')
|
||||
e0 = self.entity2id[triple[0]]
|
||||
e1 = self.entity2id[triple[2]]
|
||||
r = triple[1]
|
||||
edge_list.append((e0, e1, r))
|
||||
edge_list.append((e1, e0, r))
|
||||
edge_list.append((e0, e0, 'SELF_LOOP'))
|
||||
if e1 != e0:
|
||||
edge_list.append((e1, e1, 'SELF_LOOP'))
|
||||
|
||||
relation2id, edges, entities = dict(), set(), set()
|
||||
for h, t, r in edge_list:
|
||||
if r not in relation2id:
|
||||
relation2id[r] = len(relation2id)
|
||||
edges.add((h, t, relation2id[r]))
|
||||
entities.add(self.id2entity[h])
|
||||
entities.add(self.id2entity[t])
|
||||
|
||||
return {
|
||||
'edge': list(edges),
|
||||
'n_relation': len(relation2id),
|
||||
'entity': list(entities)
|
||||
}
|
||||
|
||||
def _word_kg_process(self):
|
||||
edges = set() # {(entity, entity)}
|
||||
entities = set()
|
||||
for line in self.word_kg:
|
||||
triple = line.strip().split('\t')
|
||||
entities.add(triple[0])
|
||||
entities.add(triple[2])
|
||||
e0 = self.word2id[triple[0]]
|
||||
e1 = self.word2id[triple[2]]
|
||||
edges.add((e0, e1))
|
||||
edges.add((e1, e0))
|
||||
# edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
|
||||
return {
|
||||
'edge': list(edges),
|
||||
'entity': list(entities)
|
||||
}
|
70
HyCoRec/crslab/data/dataset/durecdial/resources.py
Normal file
70
HyCoRec/crslab/data/dataset/durecdial/resources.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from crslab.download import DownloadableFile
|
||||
|
||||
resources = {
|
||||
'jieba': {
|
||||
'version': '0.3',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQ5u_Mos1JBFo4MAN8DinUQB7dPWuTsIHGjjvMougLfYaQ?download=1',
|
||||
'durecdial_jieba.zip',
|
||||
'c2d24f7d262e24e45a9105161b5eb15057c96c291edb3a2a7b23c9c637fd3813',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
},
|
||||
},
|
||||
'bert': {
|
||||
'version': '0.3',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETGpJYjEM9tFhze2VfD33cQBDwa7zq07EUr94zoPZvMPtA?download=1',
|
||||
'durecdial_bert.zip',
|
||||
'0126803aee62a5a4d624d8401814c67bee724ad0af5226d421318ac4eec496f5'
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 101,
|
||||
'end': 102,
|
||||
'unk': 100,
|
||||
'sent_split': 2,
|
||||
'word_split': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
'pad_topic': 0
|
||||
},
|
||||
},
|
||||
'gpt2': {
|
||||
'version': '0.3',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETxJk-3Kd6tDgFvPhLo9bLUBfVsVZlF80QCnGFcVgusdJg?download=1',
|
||||
'durecdial_gpt2.zip',
|
||||
'a7a93292b4e4b8a5e5a2c644f85740e625e04fbd3da76c655150c00f97d405e4'
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 101,
|
||||
'end': 102,
|
||||
'unk': 100,
|
||||
'cls': 101,
|
||||
'sep': 102,
|
||||
'sent_split': 2,
|
||||
'word_split': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
'pad_topic': 0,
|
||||
},
|
||||
}
|
||||
}
|
1
HyCoRec/crslab/data/dataset/hredial/__init__.py
Normal file
1
HyCoRec/crslab/data/dataset/hredial/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .hredial import HReDialDataset
|
209
HyCoRec/crslab/data/dataset/hredial/hredial.py
Normal file
209
HyCoRec/crslab/data/dataset/hredial/hredial.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/23, 2021/1/3, 2020/12/19
|
||||
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail
|
||||
|
||||
r"""
|
||||
ReDial
|
||||
======
|
||||
References:
|
||||
Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS 2018.
|
||||
|
||||
.. _`"Towards deep conversational recommendations."`:
|
||||
https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import pickle as pkl
|
||||
from copy import copy
|
||||
|
||||
from loguru import logger
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from crslab.config import DATASET_PATH
|
||||
from crslab.data.dataset.base import BaseDataset
|
||||
from .resources import resources
|
||||
|
||||
|
||||
class HReDialDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
Attributes:
|
||||
train_data: train dataset.
|
||||
valid_data: valid dataset.
|
||||
test_data: test dataset.
|
||||
vocab (dict): ::
|
||||
|
||||
{
|
||||
'tok2ind': map from token to index,
|
||||
'ind2tok': map from index to token,
|
||||
'entity2id': map from entity to index,
|
||||
'id2entity': map from index to entity,
|
||||
'word2id': map from word to index,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': max(self.entity2id.values()) + 1,
|
||||
'n_word': max(self.word2id.values()) + 1,
|
||||
}
|
||||
|
||||
Notes:
|
||||
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, tokenize, restore=False, save=False):
|
||||
"""Specify tokenized resource and init base dataset.
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataset or the whole system.
|
||||
tokenize (str): how to tokenize dataset.
|
||||
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
|
||||
save (bool): whether to save dataset after processing. Defaults to False.
|
||||
|
||||
"""
|
||||
resource = resources[tokenize]
|
||||
dpath = os.path.join(DATASET_PATH, "hredial", tokenize)
|
||||
super().__init__(opt, dpath, resource, restore, save)
|
||||
|
||||
def _load_data(self):
|
||||
train_data, valid_data, test_data = self._load_raw_data()
|
||||
self._load_vocab()
|
||||
self._load_other_data()
|
||||
|
||||
vocab = {
|
||||
'tok2ind': self.tok2ind,
|
||||
'ind2tok': self.ind2tok,
|
||||
'entity2id': self.entity2id,
|
||||
'id2entity': self.id2entity,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': self.n_entity
|
||||
}
|
||||
|
||||
return train_data, valid_data, test_data, vocab
|
||||
|
||||
def _load_raw_data(self):
|
||||
# load train/valid/test data
|
||||
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
|
||||
train_data = json.load(f)
|
||||
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
|
||||
valid_data = json.load(f)
|
||||
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
|
||||
test_data = json.load(f)
|
||||
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
|
||||
|
||||
return train_data, valid_data, test_data
|
||||
|
||||
def _load_vocab(self):
|
||||
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
|
||||
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
|
||||
|
||||
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
|
||||
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
|
||||
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
|
||||
|
||||
def _load_other_data(self):
|
||||
# edge extension data
|
||||
self.conv2items = json.load(open(os.path.join(self.dpath, 'conv2items.json'), 'r', encoding='utf-8'))
|
||||
# dbpedia
|
||||
self.entity2id = json.load(
|
||||
open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id}
|
||||
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
|
||||
self.n_entity = max(self.entity2id.values()) + 1
|
||||
self.side_data = pkl.load(open(os.path.join(self.dpath, 'side_data.pkl'), 'rb'))
|
||||
logger.debug(
|
||||
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]")
|
||||
|
||||
def _data_preprocess(self, train_data, valid_data, test_data):
|
||||
processed_train_data = self._raw_data_process(train_data)
|
||||
logger.debug("[Finish train data process]")
|
||||
processed_valid_data = self._raw_data_process(valid_data)
|
||||
logger.debug("[Finish valid data process]")
|
||||
processed_test_data = self._raw_data_process(test_data)
|
||||
logger.debug("[Finish test data process]")
|
||||
processed_side_data = self.side_data
|
||||
logger.debug("[Finish side data process]")
|
||||
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
|
||||
|
||||
def _raw_data_process(self, raw_data):
|
||||
augmented_convs = [self._convert_to_id(conv) for convs in tqdm(raw_data) for conv in convs]
|
||||
augmented_conv_dicts = []
|
||||
for conv in tqdm(augmented_convs):
|
||||
augmented_conv_dicts.extend(self._augment_and_add(conv))
|
||||
return augmented_conv_dicts
|
||||
|
||||
# Convert text, movie, and entity mentions into index ids
|
||||
def _convert_to_id(self, conversation):
|
||||
augmented_convs = []
|
||||
last_role = None
|
||||
conv_id = conversation["conv_id"]
|
||||
related_item = []
|
||||
related_entity = []
|
||||
related_word = []
|
||||
for utt in conversation['dialog']:
|
||||
self.unk_token_idx = 3  # hard-coded 'unk' index; see 'special_token_idx' in resources.py
|
||||
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
|
||||
item_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]
|
||||
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
|
||||
word_ids = [self.tok2ind[word] for word in utt['text'] if word in self.tok2ind]
|
||||
|
||||
related_item += item_ids
|
||||
related_entity += entity_ids
|
||||
related_word += word_ids
|
||||
|
||||
if utt["role"] == last_role:
|
||||
augmented_convs[-1]["text"] += text_token_ids
|
||||
augmented_convs[-1]["item"] += item_ids
|
||||
augmented_convs[-1]["entity"] += entity_ids
|
||||
augmented_convs[-1]["word"] += word_ids
|
||||
else:
|
||||
augmented_convs.append({
|
||||
"conv_id": conv_id,
|
||||
"role": utt["role"],
|
||||
"text": text_token_ids,
|
||||
"item": related_item,
|
||||
"entity": related_entity,
|
||||
"word": related_word
|
||||
})
|
||||
last_role = utt["role"]
|
||||
|
||||
return augmented_convs
|
||||
|
||||
def _augment_and_add(self, raw_conv_dict):
|
||||
augmented_conv_dicts = []
|
||||
context_tokens, context_entities, context_words, context_items = [], [], [], []
|
||||
entity_set, word_set = set(), set()
|
||||
for i, conv in enumerate(raw_conv_dict):
|
||||
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"]
|
||||
if len(context_tokens) > 0:
|
||||
conv_dict = {
|
||||
"conv_id": conv["conv_id"],
|
||||
"role": conv["role"],
|
||||
"tokens": context_tokens,
|
||||
"response": text_tokens,
|
||||
"item": context_items,
|
||||
"entity": context_entities,
|
||||
"word": context_words,
|
||||
"items": movies
|
||||
}
|
||||
augmented_conv_dicts.append(conv_dict)
|
||||
|
||||
context_tokens.append(text_tokens)
|
||||
context_items += movies
|
||||
for entity in entities + movies:
|
||||
if entity not in entity_set:
|
||||
entity_set.add(entity)
|
||||
context_entities.append(entity)
|
||||
for word in words:
|
||||
if word not in word_set:
|
||||
word_set.add(word)
|
||||
context_words.append(word)
|
||||
|
||||
return augmented_conv_dicts
|
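A minimal standalone sketch of the context-accumulation pattern used by _augment_and_add above: each turn becomes a training sample whose inputs are snapshots of everything said before it. The ids below are made up for illustration.

# Toy illustration of the per-turn context accumulation (hypothetical ids).
merged_turns = [
    {"text": [5, 6], "entity": [10], "item": [10]},
    {"text": [7, 8], "entity": [11], "item": []},
]
samples, context_tokens, context_entities = [], [], []
seen_entities = set()
for turn in merged_turns:
    if context_tokens:
        samples.append({
            "tokens": list(context_tokens),   # all previous utterances
            "response": turn["text"],         # current utterance is the target
            "entity": list(context_entities),
        })
    context_tokens.append(turn["text"])
    for e in turn["entity"] + turn["item"]:
        if e not in seen_entities:
            seen_entities.add(e)
            context_entities.append(e)
print(len(samples))  # 1: the second turn conditioned on the first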
66
HyCoRec/crslab/data/dataset/hredial/resources.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/1
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from crslab.download import DownloadableFile
|
||||
|
||||
resources = {
|
||||
'nltk': {
|
||||
'version': '0.31',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1',
|
||||
'redial_nltk.zip',
|
||||
'01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0
|
||||
},
|
||||
},
|
||||
'bert': {
|
||||
'version': '0.31',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1',
|
||||
'redial_bert.zip',
|
||||
'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 101,
|
||||
'end': 102,
|
||||
'unk': 100,
|
||||
'sent_split': 2,
|
||||
'word_split': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
},
|
||||
},
|
||||
'gpt2': {
|
||||
'version': '0.31',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1',
|
||||
'redial_gpt2.zip',
|
||||
'37b1a64032241903a37b5e014ee36e50d09f7e4a849058688e9af52027a3ac36',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'sent_split': 4,
|
||||
'word_split': 5,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0
|
||||
},
|
||||
}
|
||||
}
|
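For reference, a sketch of how the dictionary above is consumed by HReDialDataset.__init__; the 'nltk' key and the import path are assumptions based on the layout of this commit.

# Hypothetical lookup mirroring the dataset constructor (assumes tokenize == 'nltk').
from crslab.data.dataset.hredial.resources import resources

resource = resources['nltk']
special_token_idx = resource['special_token_idx']
print(special_token_idx['unk'])  # -> 3, the value hard-coded as unk_token_idx above
print(resource['version'])       # -> '0.31'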
1
HyCoRec/crslab/data/dataset/htgredial/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .htgredial import HTGReDialDataset
|
210
HyCoRec/crslab/data/dataset/htgredial/htgredial.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/23, 2021/1/3, 2020/12/19
|
||||
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail
|
||||
|
||||
r"""
|
||||
ReDial
|
||||
======
|
||||
References:
|
||||
Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS 2018.
|
||||
|
||||
.. _`"Towards deep conversational recommendations."`:
|
||||
https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import pickle as pkl
|
||||
from copy import copy
|
||||
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from crslab.config import DATASET_PATH
|
||||
from crslab.data.dataset.base import BaseDataset
|
||||
from .resources import resources
|
||||
|
||||
|
||||
class HTGReDialDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
Attributes:
|
||||
train_data: train dataset.
|
||||
valid_data: valid dataset.
|
||||
test_data: test dataset.
|
||||
vocab (dict): ::
|
||||
|
||||
{
|
||||
'tok2ind': map from token to index,
|
||||
'ind2tok': map from index to token,
|
||||
'entity2id': map from entity to index,
|
||||
'id2entity': map from index to entity,
|
||||
'word2id': map from word to index,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': max(self.entity2id.values()) + 1,
|
||||
'n_word': max(self.word2id.values()) + 1,
|
||||
}
|
||||
|
||||
Notes:
|
||||
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, tokenize, restore=False, save=False):
|
||||
"""Specify tokenized resource and init base dataset.
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataset or the whole system.
|
||||
tokenize (str): how to tokenize dataset.
|
||||
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
|
||||
save (bool): whether to save dataset after processing. Defaults to False.
|
||||
|
||||
"""
|
||||
resource = resources[tokenize]
|
||||
dpath = os.path.join(DATASET_PATH, "htgredial", tokenize)
|
||||
super().__init__(opt, dpath, resource, restore, save)
|
||||
|
||||
def _load_data(self):
|
||||
train_data, valid_data, test_data = self._load_raw_data()
|
||||
self._load_vocab()
|
||||
self._load_other_data()
|
||||
|
||||
vocab = {
|
||||
'tok2ind': self.tok2ind,
|
||||
'ind2tok': self.ind2tok,
|
||||
'entity2id': self.entity2id,
|
||||
'id2entity': self.id2entity,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': self.n_entity
|
||||
}
|
||||
|
||||
return train_data, valid_data, test_data, vocab
|
||||
|
||||
def _load_raw_data(self):
|
||||
# load train/valid/test data
|
||||
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
|
||||
train_data = json.load(f)
|
||||
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
|
||||
valid_data = json.load(f)
|
||||
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
|
||||
test_data = json.load(f)
|
||||
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
|
||||
|
||||
return train_data, valid_data, test_data
|
||||
|
||||
def _load_vocab(self):
|
||||
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
|
||||
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
|
||||
|
||||
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
|
||||
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
|
||||
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
|
||||
|
||||
def _load_other_data(self):
|
||||
# edge extension data
|
||||
self.conv2items = json.load(open(os.path.join(self.dpath, 'conv2items.json'), 'r', encoding='utf-8'))
|
||||
# hypergraph
|
||||
self.user2hypergraph = json.load(open(os.path.join(self.dpath, 'user2hypergraph.json'), 'r', encoding='utf-8'))
|
||||
# dbpedia
|
||||
self.entity2id = json.load(
|
||||
open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id}
|
||||
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
|
||||
self.n_entity = max(self.entity2id.values()) + 1
|
||||
self.side_data = pkl.load(open(os.path.join(self.dpath, 'side_data.pkl'), 'rb'))
|
||||
logger.debug(
|
||||
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]")
|
||||
|
||||
def _data_preprocess(self, train_data, valid_data, test_data):
|
||||
processed_train_data = self._raw_data_process(train_data)
|
||||
logger.debug("[Finish train data process]")
|
||||
processed_valid_data = self._raw_data_process(valid_data)
|
||||
logger.debug("[Finish valid data process]")
|
||||
processed_test_data = self._raw_data_process(test_data)
|
||||
logger.debug("[Finish test data process]")
|
||||
processed_side_data = self.side_data
|
||||
logger.debug("[Finish side data process]")
|
||||
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
|
||||
|
||||
def _raw_data_process(self, raw_data):
|
||||
augmented_convs = [self._convert_to_id(conv) for convs in tqdm(raw_data) for conv in convs]
|
||||
augmented_conv_dicts = []
|
||||
for conv in tqdm(augmented_convs):
|
||||
augmented_conv_dicts.extend(self._augment_and_add(conv))
|
||||
return augmented_conv_dicts
|
||||
|
||||
# Convert text, movie, and entity mentions into index ids
|
||||
def _convert_to_id(self, conversation):
|
||||
augmented_convs = []
|
||||
last_role = None
|
||||
conv_id = conversation["conv_id"]
|
||||
related_item = []
|
||||
related_entity = []
|
||||
related_word = []
|
||||
for utt in conversation['dialog']:
|
||||
self.unk_token_idx = 3  # hard-coded 'unk' index; see 'special_token_idx' in resources.py
|
||||
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
|
||||
item_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]
|
||||
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
|
||||
word_ids = [self.tok2ind[word] for word in utt['text'] if word in self.tok2ind]
|
||||
|
||||
related_item += item_ids
|
||||
related_entity += entity_ids
|
||||
related_word += word_ids
|
||||
|
||||
if utt["role"] == last_role:
|
||||
augmented_convs[-1]["text"] += text_token_ids
|
||||
augmented_convs[-1]["item"] += item_ids
|
||||
augmented_convs[-1]["entity"] += entity_ids
|
||||
augmented_convs[-1]["word"] += word_ids
|
||||
else:
|
||||
augmented_convs.append({
|
||||
"conv_id": conv_id,
|
||||
"role": utt["role"],
|
||||
"text": text_token_ids,
|
||||
"item": related_item,
|
||||
"entity": related_entity,
|
||||
"word": related_word
|
||||
})
|
||||
last_role = utt["role"]
|
||||
|
||||
return augmented_convs
|
||||
|
||||
def _augment_and_add(self, raw_conv_dict):
|
||||
augmented_conv_dicts = []
|
||||
context_tokens, context_entities, context_words, context_items = [], [], [], []
|
||||
entity_set, word_set = set(), set()
|
||||
for i, conv in enumerate(raw_conv_dict):
|
||||
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"]
|
||||
if len(context_tokens) > 0:
|
||||
conv_dict = {
|
||||
"conv_id": conv["conv_id"],
|
||||
"role": conv["role"],
|
||||
"tokens": copy(context_tokens),
|
||||
"response": text_tokens,
|
||||
"item": copy(context_items),
|
||||
"entity": copy(context_entities),
|
||||
"word": copy(context_words),
|
||||
"items": movies
|
||||
}
|
||||
augmented_conv_dicts.append(conv_dict)
|
||||
|
||||
context_tokens.append(text_tokens)
|
||||
context_items += movies
|
||||
for entity in entities + movies:
|
||||
if entity not in entity_set:
|
||||
entity_set.add(entity)
|
||||
context_entities.append(entity)
|
||||
for word in words:
|
||||
if word not in word_set:
|
||||
word_set.add(word)
|
||||
context_words.append(word)
|
||||
|
||||
return augmented_conv_dicts
|
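A short aside on why _augment_and_add above wraps the running context lists in copy(...): without the snapshot, every stored sample would alias the same list and end up holding the full final context. Toy sketch, independent of the dataset classes:

from copy import copy

context = []
aliased, snapshotted = [], []
for turn_tokens in ([1, 2], [3], [4, 5]):
    aliased.append(context)            # stores a reference to the shared list
    snapshotted.append(copy(context))  # stores the state at this point in time
    context.append(turn_tokens)

print(aliased[0])      # [[1, 2], [3], [4, 5]] -- mutated after the fact
print(snapshotted[0])  # [] -- the context as it was before the first turn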
66
HyCoRec/crslab/data/dataset/htgredial/resources.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/1
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from crslab.download import DownloadableFile
|
||||
|
||||
resources = {
|
||||
'pkuseg': {
|
||||
'version': '0.31',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1',
|
||||
'redial_nltk.zip',
|
||||
'01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0
|
||||
},
|
||||
},
|
||||
'bert': {
|
||||
'version': '0.31',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1',
|
||||
'redial_bert.zip',
|
||||
'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 101,
|
||||
'end': 102,
|
||||
'unk': 100,
|
||||
'sent_split': 2,
|
||||
'word_split': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
},
|
||||
},
|
||||
'gpt2': {
|
||||
'version': '0.31',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1',
|
||||
'redial_gpt2.zip',
|
||||
'37b1a64032241903a37b5e014ee36e50d09f7e4a849058688e9af52027a3ac36',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'sent_split': 4,
|
||||
'word_split': 5,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0
|
||||
},
|
||||
}
|
||||
}
|
1
HyCoRec/crslab/data/dataset/opendialkg/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .opendialkg import OpenDialKGDataset
|
286
HyCoRec/crslab/data/dataset/opendialkg/opendialkg.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# @Time : 2020/12/19
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/12/20, 2021/1/2
|
||||
# @Author : Kun Zhou, Xiaolei Wang
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
|
||||
|
||||
r"""
|
||||
OpenDialKG
|
||||
==========
|
||||
References:
|
||||
Moon, Seungwhan, et al. `"Opendialkg: Explainable conversational reasoning with attention-based walks over knowledge graphs."`_ in ACL 2019.
|
||||
|
||||
.. _`"Opendialkg: Explainable conversational reasoning with attention-based walks over knowledge graphs."`:
|
||||
https://www.aclweb.org/anthology/P19-1081/
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from crslab.config import DATASET_PATH
|
||||
from crslab.data.dataset.base import BaseDataset
|
||||
from .resources import resources
|
||||
|
||||
|
||||
class OpenDialKGDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
Attributes:
|
||||
train_data: train dataset.
|
||||
valid_data: valid dataset.
|
||||
test_data: test dataset.
|
||||
vocab (dict): ::
|
||||
|
||||
{
|
||||
'tok2ind': map from token to index,
|
||||
'ind2tok': map from index to token,
|
||||
'entity2id': map from entity to index,
|
||||
'id2entity': map from index to entity,
|
||||
'word2id': map from word to index,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': max(self.entity2id.values()) + 1,
|
||||
'n_word': max(self.word2id.values()) + 1,
|
||||
}
|
||||
|
||||
Notes:
|
||||
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, tokenize, restore=False, save=False):
|
||||
"""Specify tokenized resource and init base dataset.
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataset or the whole system.
|
||||
tokenize (str): how to tokenize dataset.
|
||||
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
|
||||
save (bool): whether to save dataset after processing. Defaults to False.
|
||||
|
||||
"""
|
||||
resource = resources[tokenize]
|
||||
self.special_token_idx = resource['special_token_idx']
|
||||
self.unk_token_idx = self.special_token_idx['unk']
|
||||
dpath = os.path.join(DATASET_PATH, 'opendialkg', tokenize)
|
||||
self.train_review = None
|
||||
self.valid_review = None
|
||||
self.test_review = None
|
||||
super().__init__(opt, dpath, resource, restore, save)
|
||||
|
||||
def _load_data(self):
|
||||
train_data, valid_data, test_data = self._load_raw_data()
|
||||
self._load_vocab()
|
||||
self._load_other_data()
|
||||
|
||||
vocab = {
|
||||
'tok2ind': self.tok2ind,
|
||||
'ind2tok': self.ind2tok,
|
||||
'entity2id': self.entity2id,
|
||||
'id2entity': self.id2entity,
|
||||
'word2id': self.word2id,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': self.n_entity,
|
||||
'n_word': self.n_word,
|
||||
}
|
||||
vocab.update(self.special_token_idx)
|
||||
|
||||
return train_data, valid_data, test_data, vocab
|
||||
|
||||
def _load_raw_data(self):
|
||||
# load train/valid/test data
|
||||
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
|
||||
train_data = json.load(f)
|
||||
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
|
||||
valid_data = json.load(f)
|
||||
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
|
||||
test_data = json.load(f)
|
||||
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
|
||||
|
||||
return train_data, valid_data, test_data
|
||||
|
||||
def _load_vocab(self):
|
||||
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
|
||||
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
|
||||
|
||||
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
|
||||
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
|
||||
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
|
||||
|
||||
def _load_other_data(self):
|
||||
# opendialkg
|
||||
self.entity2id = json.load(
|
||||
open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id}
|
||||
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
|
||||
self.n_entity = max(self.entity2id.values()) + 1
|
||||
# {head_entity_id: [(relation_id, tail_entity_id)]}
|
||||
self.entity_kg = open(os.path.join(self.dpath, 'opendialkg_subkg.txt'), encoding='utf-8')
|
||||
logger.debug(
|
||||
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'opendialkg_subkg.json')} and {os.path.join(self.dpath, 'opendialkg_triples.txt')}]")
|
||||
|
||||
# conceptnet
|
||||
# {concept: concept_id}
|
||||
self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8'))
|
||||
self.n_word = max(self.word2id.values()) + 1
|
||||
# {concept \t relation\t concept}
|
||||
self.word_kg = open(os.path.join(self.dpath, 'concept_subkg.txt'), encoding='utf-8')
|
||||
logger.debug(
|
||||
f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'concept_subkg.txt')}]")
|
||||
|
||||
def _data_preprocess(self, train_data, valid_data, test_data):
|
||||
processed_train_data = self._raw_data_process(train_data)
|
||||
logger.debug("[Finish train data process]")
|
||||
processed_valid_data = self._raw_data_process(valid_data)
|
||||
logger.debug("[Finish valid data process]")
|
||||
processed_test_data = self._raw_data_process(test_data)
|
||||
logger.debug("[Finish test data process]")
|
||||
processed_side_data = self._side_data_process()
|
||||
logger.debug("[Finish side data process]")
|
||||
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
|
||||
|
||||
def _raw_data_process(self, raw_data):
|
||||
augmented_convs = [self._convert_to_id(idx, conversation) for idx, conversation in enumerate(tqdm(raw_data))]
|
||||
augmented_conv_dicts = []
|
||||
for conv in tqdm(augmented_convs):
|
||||
augmented_conv_dicts.extend(self._augment_and_add(conv))
|
||||
return augmented_conv_dicts
|
||||
|
||||
def _convert_to_id(self, idx, conversation):
|
||||
augmented_convs = []
|
||||
last_role = None
|
||||
conv_id = conversation.get("conv_id", idx)
|
||||
related_item = []
|
||||
related_entity = []
|
||||
related_word = []
|
||||
for utt in conversation['dialog']:
|
||||
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
|
||||
item_ids = [self.entity2id[movie] for movie in utt['item'] if movie in self.entity2id]
|
||||
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
|
||||
word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]
|
||||
|
||||
related_item += item_ids
|
||||
related_entity += entity_ids
|
||||
related_word += word_ids
|
||||
|
||||
if utt["role"] == last_role:
|
||||
augmented_convs[-1]["text"] += text_token_ids
|
||||
augmented_convs[-1]["item"] += item_ids
|
||||
augmented_convs[-1]["entity"] += entity_ids
|
||||
augmented_convs[-1]["word"] += word_ids
|
||||
else:
|
||||
augmented_convs.append({
|
||||
"conv_id": conv_id,
|
||||
"role": utt["role"],
|
||||
"text": text_token_ids,
|
||||
"item": related_item,
|
||||
"entity": related_entity,
|
||||
"word": related_word
|
||||
})
|
||||
last_role = utt["role"]
|
||||
|
||||
return augmented_convs
|
||||
|
||||
def _augment_and_add(self, raw_conv_dict):
|
||||
augmented_conv_dicts = []
|
||||
context_tokens, context_entities, context_words, context_items = [], [], [], []
|
||||
entity_set, word_set = set(), set()
|
||||
for i, conv in enumerate(raw_conv_dict):
|
||||
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["item"], conv["word"]
|
||||
if len(context_tokens) > 0:
|
||||
conv_dict = {
|
||||
"conv_id": conv["conv_id"],
|
||||
"role": conv["role"],
|
||||
"tokens": copy(context_tokens),
|
||||
"response": text_tokens,
|
||||
"item": copy(context_items),
|
||||
"entity": copy(context_entities),
|
||||
"word": copy(context_words),
|
||||
"items": movies
|
||||
}
|
||||
augmented_conv_dicts.append(conv_dict)
|
||||
|
||||
context_tokens.append(text_tokens)
|
||||
context_items += movies
|
||||
for entity in entities + movies:
|
||||
if entity not in entity_set:
|
||||
entity_set.add(entity)
|
||||
context_entities.append(entity)
|
||||
for word in words:
|
||||
if word not in word_set:
|
||||
word_set.add(word)
|
||||
context_words.append(word)
|
||||
|
||||
return augmented_conv_dicts
|
||||
|
||||
def _side_data_process(self):
|
||||
processed_entity_kg = self._entity_kg_process()
|
||||
logger.debug("[Finish entity KG process]")
|
||||
processed_word_kg = self._word_kg_process()
|
||||
logger.debug("[Finish word KG process]")
|
||||
item_entity_ids = json.load(open(os.path.join(self.dpath, 'item_ids.json'), 'r', encoding='utf-8'))
|
||||
logger.debug('[Load item entity ids]')
|
||||
|
||||
side_data = {
|
||||
"entity_kg": processed_entity_kg,
|
||||
"word_kg": processed_word_kg,
|
||||
"item_entity_ids": item_entity_ids,
|
||||
}
|
||||
return side_data
|
||||
|
||||
def _entity_kg_process(self):
|
||||
edge_list = [] # [(entity, entity, relation)]
|
||||
for line in self.entity_kg:
|
||||
triple = line.strip().split('\t')
|
||||
if len(triple) != 3 or triple[0] not in self.entity2id or triple[2] not in self.entity2id:
|
||||
continue
|
||||
e0 = self.entity2id[triple[0]]
|
||||
e1 = self.entity2id[triple[2]]
|
||||
r = triple[1]
|
||||
edge_list.append((e0, e1, r))
|
||||
# edge_list.append((e1, e0, r))
|
||||
edge_list.append((e0, e0, 'SELF_LOOP'))
|
||||
if e1 != e0:
|
||||
edge_list.append((e1, e1, 'SELF_LOOP'))
|
||||
|
||||
relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set()
|
||||
for h, t, r in edge_list:
|
||||
relation_cnt[r] += 1
|
||||
for h, t, r in edge_list:
|
||||
if relation_cnt[r] > 20000:
|
||||
if r not in relation2id:
|
||||
relation2id[r] = len(relation2id)
|
||||
edges.add((h, t, relation2id[r]))
|
||||
entities.add(self.id2entity[h])
|
||||
entities.add(self.id2entity[t])
|
||||
|
||||
return {
|
||||
'edge': list(edges),
|
||||
'n_relation': len(relation2id),
|
||||
'entity': list(entities)
|
||||
}
|
||||
|
||||
def _word_kg_process(self):
|
||||
edges = set() # {(entity, entity)}
|
||||
entities = set()
|
||||
for line in self.word_kg:
|
||||
triple = line.strip().split('\t')
|
||||
entities.add(triple[0])
|
||||
entities.add(triple[2])
|
||||
e0 = self.word2id[triple[0]]
|
||||
e1 = self.word2id[triple[2]]
|
||||
edges.add((e0, e1))
|
||||
edges.add((e1, e0))
|
||||
# edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
|
||||
return {
|
||||
'edge': list(edges),
|
||||
'entity': list(entities)
|
||||
}
|
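The relation filtering in _entity_kg_process above keeps only relations seen more than 20000 times and remaps them to compact ids in two passes. A self-contained sketch of the same pattern, with made-up triples and a toy threshold of 1:

from collections import defaultdict

edge_list = [(0, 1, 'starred_in'), (2, 1, 'starred_in'), (3, 4, 'rare_rel')]
threshold = 1  # the method above uses 20000

relation_cnt = defaultdict(int)
for _, _, r in edge_list:
    relation_cnt[r] += 1           # first pass: count each relation

relation2id, edges = {}, set()
for h, t, r in edge_list:
    if relation_cnt[r] > threshold:
        relation2id.setdefault(r, len(relation2id))   # second pass: keep and re-index frequent relations
        edges.add((h, t, relation2id[r]))

print(edges)        # {(0, 1, 0), (2, 1, 0)} -- 'rare_rel' is dropped
print(relation2id)  # {'starred_in': 0}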
66
HyCoRec/crslab/data/dataset/opendialkg/resources.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/21
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from crslab.download import DownloadableFile
|
||||
|
||||
resources = {
|
||||
'nltk': {
|
||||
'version': '0.3',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ESB7grlJlehKv7XmYgMgq5AB85LhRu_rSW93_kL8Arfrhw?download=1',
|
||||
'opendialkg_nltk.zip',
|
||||
'6487f251ac74911e35bec690469fba52a7df14908575229b63ee30f63885c32f'
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
},
|
||||
},
|
||||
'bert': {
|
||||
'version': '0.3',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EWab0Pzgb4JOiecUHZxVaEEBRDBMoeLZDlStrr7YxentRA?download=1',
|
||||
'opendialkg_bert.zip',
|
||||
'0ec3ff45214fac9af570744e9b5893f224aab931744c70b7eeba7e1df13a4f07'
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 101,
|
||||
'end': 102,
|
||||
'unk': 100,
|
||||
'sent_split': 2,
|
||||
'word_split': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
},
|
||||
},
|
||||
'gpt2': {
|
||||
'version': '0.3',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdE5iyKIoAhLvCwwBN4MdJwB2wsDADxJCs_KRaH-G3b7kg?download=1',
|
||||
'opendialkg_gpt2.zip',
|
||||
'dec20b01247cfae733988d7f7bfd1c99f4bb8ba7786b3fdaede5c9a618c6d71e'
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'sent_split': 4,
|
||||
'word_split': 5,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0
|
||||
},
|
||||
}
|
||||
}
|
1
HyCoRec/crslab/data/dataset/redial/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .redial import ReDialDataset
|
269
HyCoRec/crslab/data/dataset/redial/redial.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/23, 2021/1/3, 2020/12/19
|
||||
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail
|
||||
|
||||
r"""
|
||||
ReDial
|
||||
======
|
||||
References:
|
||||
Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS 2018.
|
||||
|
||||
.. _`"Towards deep conversational recommendations."`:
|
||||
https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from crslab.config import DATASET_PATH
|
||||
from crslab.data.dataset.base import BaseDataset
|
||||
from .resources import resources
|
||||
|
||||
|
||||
class ReDialDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
Attributes:
|
||||
train_data: train dataset.
|
||||
valid_data: valid dataset.
|
||||
test_data: test dataset.
|
||||
vocab (dict): ::
|
||||
|
||||
{
|
||||
'tok2ind': map from token to index,
|
||||
'ind2tok': map from index to token,
|
||||
'entity2id': map from entity to index,
|
||||
'id2entity': map from index to entity,
|
||||
'word2id': map from word to index,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': max(self.entity2id.values()) + 1,
|
||||
'n_word': max(self.word2id.values()) + 1,
|
||||
}
|
||||
|
||||
Notes:
|
||||
``'unk'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, tokenize, restore=False, save=False):
|
||||
"""Specify tokenized resource and init base dataset.
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataset or the whole system.
|
||||
tokenize (str): how to tokenize dataset.
|
||||
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
|
||||
save (bool): whether to save dataset after processing. Defaults to False.
|
||||
|
||||
"""
|
||||
resource = resources[tokenize]
|
||||
self.special_token_idx = resource['special_token_idx']
|
||||
self.unk_token_idx = self.special_token_idx['unk']
|
||||
dpath = os.path.join(DATASET_PATH, "redial", tokenize)
|
||||
super().__init__(opt, dpath, resource, restore, save)
|
||||
|
||||
def _load_data(self):
|
||||
train_data, valid_data, test_data = self._load_raw_data()
|
||||
self._load_vocab()
|
||||
self._load_other_data()
|
||||
|
||||
vocab = {
|
||||
'tok2ind': self.tok2ind,
|
||||
'ind2tok': self.ind2tok,
|
||||
'entity2id': self.entity2id,
|
||||
'id2entity': self.id2entity,
|
||||
'word2id': self.word2id,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_entity': self.n_entity,
|
||||
'n_word': self.n_word,
|
||||
}
|
||||
vocab.update(self.special_token_idx)
|
||||
|
||||
return train_data, valid_data, test_data, vocab
|
||||
|
||||
def _load_raw_data(self):
|
||||
# load train/valid/test data
|
||||
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
|
||||
train_data = json.load(f)
|
||||
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
|
||||
valid_data = json.load(f)
|
||||
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
|
||||
test_data = json.load(f)
|
||||
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
|
||||
|
||||
return train_data, valid_data, test_data
|
||||
|
||||
def _load_vocab(self):
|
||||
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
|
||||
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
|
||||
|
||||
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
|
||||
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
|
||||
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
|
||||
|
||||
def _load_other_data(self):
|
||||
# dbpedia
|
||||
self.entity2id = json.load(
|
||||
open(os.path.join(self.dpath, 'entity2id.json'), 'r', encoding='utf-8')) # {entity: entity_id}
|
||||
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
|
||||
self.n_entity = max(self.entity2id.values()) + 1
|
||||
# {head_entity_id: [(relation_id, tail_entity_id)]}
|
||||
self.entity_kg = json.load(open(os.path.join(self.dpath, 'dbpedia_subkg.json'), 'r', encoding='utf-8'))
|
||||
logger.debug(
|
||||
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'dbpedia_subkg.json')}]")
|
||||
|
||||
# conceptNet
|
||||
# {concept: concept_id}
|
||||
self.word2id = json.load(open(os.path.join(self.dpath, 'concept2id.json'), 'r', encoding='utf-8'))
|
||||
self.n_word = max(self.word2id.values()) + 1
|
||||
# {relation\t concept \t concept}
|
||||
self.word_kg = open(os.path.join(self.dpath, 'conceptnet_subkg.txt'), 'r', encoding='utf-8')
|
||||
logger.debug(
|
||||
f"[Load word dictionary and KG from {os.path.join(self.dpath, 'concept2id.json')} and {os.path.join(self.dpath, 'conceptnet_subkg.txt')}]")
|
||||
|
||||
def _data_preprocess(self, train_data, valid_data, test_data):
|
||||
processed_train_data = self._raw_data_process(train_data)
|
||||
logger.debug("[Finish train data process]")
|
||||
processed_valid_data = self._raw_data_process(valid_data)
|
||||
logger.debug("[Finish valid data process]")
|
||||
processed_test_data = self._raw_data_process(test_data)
|
||||
logger.debug("[Finish test data process]")
|
||||
processed_side_data = self._side_data_process()
|
||||
logger.debug("[Finish side data process]")
|
||||
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
|
||||
|
||||
def _raw_data_process(self, raw_data):
|
||||
augmented_convs = [self._merge_conv_data(conversation["dialog"]) for conversation in tqdm(raw_data)]
|
||||
augmented_conv_dicts = []
|
||||
for conv in tqdm(augmented_convs):
|
||||
augmented_conv_dicts.extend(self._augment_and_add(conv))
|
||||
return augmented_conv_dicts
|
||||
|
||||
def _merge_conv_data(self, dialog):
|
||||
augmented_convs = []
|
||||
last_role = None
|
||||
for utt in dialog:
|
||||
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
|
||||
movie_ids = [self.entity2id[movie] for movie in utt['movies'] if movie in self.entity2id]
|
||||
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
|
||||
word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]
|
||||
|
||||
if utt["role"] == last_role:
|
||||
augmented_convs[-1]["text"] += text_token_ids
|
||||
augmented_convs[-1]["movie"] += movie_ids
|
||||
augmented_convs[-1]["entity"] += entity_ids
|
||||
augmented_convs[-1]["word"] += word_ids
|
||||
else:
|
||||
augmented_convs.append({
|
||||
"role": utt["role"],
|
||||
"text": text_token_ids,
|
||||
"entity": entity_ids,
|
||||
"movie": movie_ids,
|
||||
"word": word_ids
|
||||
})
|
||||
last_role = utt["role"]
|
||||
|
||||
return augmented_convs
|
||||
|
||||
def _augment_and_add(self, raw_conv_dict):
|
||||
augmented_conv_dicts = []
|
||||
context_tokens, context_entities, context_words, context_items = [], [], [], []
|
||||
entity_set, word_set = set(), set()
|
||||
for i, conv in enumerate(raw_conv_dict):
|
||||
text_tokens, entities, movies, words = conv["text"], conv["entity"], conv["movie"], conv["word"]
|
||||
if len(context_tokens) > 0:
|
||||
conv_dict = {
|
||||
"conv_id": i,
|
||||
"role": conv['role'],
|
||||
"tokens": copy(context_tokens),
|
||||
"response": text_tokens,
|
||||
"entity": copy(context_entities),
|
||||
"word": copy(context_words),
|
||||
"item": copy(context_items),
|
||||
"items": movies,
|
||||
}
|
||||
augmented_conv_dicts.append(conv_dict)
|
||||
|
||||
context_tokens.append(text_tokens)
|
||||
context_items += movies
|
||||
for entity in entities + movies:
|
||||
if entity not in entity_set:
|
||||
entity_set.add(entity)
|
||||
context_entities.append(entity)
|
||||
for word in words:
|
||||
if word not in word_set:
|
||||
word_set.add(word)
|
||||
context_words.append(word)
|
||||
|
||||
return augmented_conv_dicts
|
||||
|
||||
def _side_data_process(self):
|
||||
processed_entity_kg = self._entity_kg_process()
|
||||
logger.debug("[Finish entity KG process]")
|
||||
processed_word_kg = self._word_kg_process()
|
||||
logger.debug("[Finish word KG process]")
|
||||
movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8'))
|
||||
logger.debug('[Load movie entity ids]')
|
||||
|
||||
side_data = {
|
||||
"entity_kg": processed_entity_kg,
|
||||
"word_kg": processed_word_kg,
|
||||
"item_entity_ids": movie_entity_ids,
|
||||
}
|
||||
return side_data
|
||||
|
||||
def _entity_kg_process(self, SELF_LOOP_ID=185):
|
||||
edge_list = [] # [(entity, entity, relation)]
|
||||
for entity in range(self.n_entity):
|
||||
if str(entity) not in self.entity_kg:
|
||||
continue
|
||||
edge_list.append((entity, entity, SELF_LOOP_ID)) # add self loop
|
||||
for tail_and_relation in self.entity_kg[str(entity)]:
|
||||
if entity != tail_and_relation[1] and tail_and_relation[0] != SELF_LOOP_ID:
|
||||
edge_list.append((entity, tail_and_relation[1], tail_and_relation[0]))
|
||||
edge_list.append((tail_and_relation[1], entity, tail_and_relation[0]))
|
||||
|
||||
relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set()
|
||||
for h, t, r in edge_list:
|
||||
relation_cnt[r] += 1
|
||||
for h, t, r in edge_list:
|
||||
if relation_cnt[r] > 1000:
|
||||
if r not in relation2id:
|
||||
relation2id[r] = len(relation2id)
|
||||
edges.add((h, t, relation2id[r]))
|
||||
entities.add(self.id2entity[h])
|
||||
entities.add(self.id2entity[t])
|
||||
return {
|
||||
'edge': list(edges),
|
||||
'n_relation': len(relation2id),
|
||||
'entity': list(entities)
|
||||
}
|
||||
|
||||
def _word_kg_process(self):
|
||||
edges = set() # {(entity, entity)}
|
||||
entities = set()
|
||||
for line in self.word_kg:
|
||||
kg = line.strip().split('\t')
|
||||
entities.add(kg[1].split('/')[0])
|
||||
entities.add(kg[2].split('/')[0])
|
||||
e0 = self.word2id[kg[1].split('/')[0]]
|
||||
e1 = self.word2id[kg[2].split('/')[0]]
|
||||
edges.add((e0, e1))
|
||||
edges.add((e1, e0))
|
||||
# edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
|
||||
return {
|
||||
'edge': list(edges),
|
||||
'entity': list(entities)
|
||||
}
|
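A sketch of the line parsing that _word_kg_process above applies to conceptnet_subkg.txt. The sample line is hypothetical and only assumes the tab-separated relation<TAB>concept/...<TAB>concept/... layout implied by the code:

word2id = {'movie': 0, 'film': 1}    # hypothetical concept vocabulary
line = 'Synonym\tmovie/n\tfilm/n\n'  # hypothetical ConceptNet-style triple

kg = line.strip().split('\t')
c0 = kg[1].split('/')[0]             # drop the part-of-speech suffix
c1 = kg[2].split('/')[0]
edges = {(word2id[c0], word2id[c1]), (word2id[c1], word2id[c0])}  # store both directions
print(edges)  # {(0, 1), (1, 0)}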
66
HyCoRec/crslab/data/dataset/redial/resources.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/1
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from crslab.download import DownloadableFile
|
||||
|
||||
resources = {
|
||||
'nltk': {
|
||||
'version': '0.31',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EdVnNcteOkpAkLdNL-ejvAABPieUd8jIty3r1jcdJvGLzw?download=1',
|
||||
'redial_nltk.zip',
|
||||
'01dc2ebf15a0988a92112daa7015ada3e95d855e80cc1474037a86e536de3424',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0
|
||||
},
|
||||
},
|
||||
'bert': {
|
||||
'version': '0.31',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EXe_sjFhfqpJoTbNcoUPJf8Bl_4U-lnduct0z8Dw5HVCPw?download=1',
|
||||
'redial_bert.zip',
|
||||
'fb55516c22acfd3ba073e05101415568ed3398c86ff56792f82426b9258c92fd',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 101,
|
||||
'end': 102,
|
||||
'unk': 100,
|
||||
'sent_split': 2,
|
||||
'word_split': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
},
|
||||
},
|
||||
'gpt2': {
|
||||
'version': '0.31',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EQHOlW2m6mFEqHgt94PfoLsBbmQQeKQEOMyL1lLEHz7LvA?download=1',
|
||||
'redial_gpt2.zip',
|
||||
'15661f1cb126210a09e30228e9477cf57bbec42140d2b1029cc50489beff4eb8',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': -100,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'sent_split': 4,
|
||||
'word_split': 5,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0
|
||||
},
|
||||
}
|
||||
}
|
1
HyCoRec/crslab/data/dataset/tgredial/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .tgredial import TGReDialDataset
|
71
HyCoRec/crslab/data/dataset/tgredial/resources.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/4
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from crslab.download import DownloadableFile
|
||||
|
||||
resources = {
|
||||
'pkuseg': {
|
||||
'version': '0.3',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/Ee7FleGfEStCimV4XRKvo-kBR8ABdPKo0g_XqgLJPxP6tg?download=1',
|
||||
'tgredial_pkuseg.zip',
|
||||
'8b7e23205778db4baa012eeb129cf8d26f4871ae98cdfe81fde6adc27a73a8d6',
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 1,
|
||||
'end': 2,
|
||||
'unk': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
'pad_topic': 0
|
||||
},
|
||||
},
|
||||
'bert': {
|
||||
'version': '0.3',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/ETC9vIeFtOdElXL10Hbh4L0BGm20-lckCJ3a4u7VFCzpIg?download=1',
|
||||
'tgredial_bert.zip',
|
||||
'd40f7072173c1dc49d4a3125f9985aaf0bd0801d7b437348ece9a894f485193b'
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 101,
|
||||
'end': 102,
|
||||
'unk': 100,
|
||||
'sent_split': 2,
|
||||
'word_split': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
'pad_topic': 0
|
||||
},
|
||||
},
|
||||
'gpt2': {
|
||||
'version': '0.3',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pkueducn_onmicrosoft_com/EcVEcxrDMF1BrbOUD8jEXt4BJeCzUjbNFL6m6UY5W3Hm3g?download=1',
|
||||
'tgredial_gpt2.zip',
|
||||
'2077f137b6a11c2fd523ca63b06e75cc19411cd515b7d5b997704d9e81778df9'
|
||||
),
|
||||
'special_token_idx': {
|
||||
'pad': 0,
|
||||
'start': 101,
|
||||
'end': 102,
|
||||
'unk': 100,
|
||||
'cls': 101,
|
||||
'sep': 102,
|
||||
'sent_split': 2,
|
||||
'word_split': 3,
|
||||
'pad_entity': 0,
|
||||
'pad_word': 0,
|
||||
'pad_topic': 0,
|
||||
},
|
||||
}
|
||||
}
|
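As a usage note: whichever entry above is selected, its special_token_idx block is merged into the dataset vocab by TGReDialDataset._load_data via vocab.update(...), so downstream modules read the ids as flat keys. Minimal sketch with made-up sizes:

special_token_idx = {'pad': 0, 'start': 1, 'end': 2, 'unk': 3, 'pad_topic': 0}  # pkuseg entry above
vocab = {'vocab_size': 30000, 'n_topic': 2500}  # hypothetical sizes
vocab.update(special_token_idx)
print(vocab['unk'], vocab['pad_topic'])  # 3 0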
343
HyCoRec/crslab/data/dataset/tgredial/tgredial.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# @Time : 2020/12/4
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/12/6, 2021/1/2, 2020/12/19
|
||||
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
|
||||
# @Email : francis_kun_zhou@163.com, sdzyh002@gmail
|
||||
|
||||
r"""
|
||||
TGReDial
|
||||
========
|
||||
References:
|
||||
Zhou, Kun, et al. `"Towards Topic-Guided Conversational Recommender System."`_ in COLING 2020.
|
||||
|
||||
.. _`"Towards Topic-Guided Conversational Recommender System."`:
|
||||
https://www.aclweb.org/anthology/2020.coling-main.365/
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from crslab.config import DATASET_PATH
|
||||
from crslab.data.dataset.base import BaseDataset
|
||||
from .resources import resources
|
||||
|
||||
|
||||
class TGReDialDataset(BaseDataset):
|
||||
"""
|
||||
|
||||
Attributes:
|
||||
train_data: train dataset.
|
||||
valid_data: valid dataset.
|
||||
test_data: test dataset.
|
||||
vocab (dict): ::
|
||||
|
||||
{
|
||||
'tok2ind': map from token to index,
|
||||
'ind2tok': map from index to token,
|
||||
'topic2ind': map from topic to index,
|
||||
'ind2topic': map from index to topic,
|
||||
'entity2id': map from entity to index,
|
||||
'id2entity': map from index to entity,
|
||||
'word2id': map from word to index,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_topic': len(self.topic2ind) + 1,
|
||||
'n_entity': max(self.entity2id.values()) + 1,
|
||||
'n_word': max(self.word2id.values()) + 1,
|
||||
}
|
||||
|
||||
Notes:
|
||||
``'unk'`` and ``'pad_topic'`` must be specified in ``'special_token_idx'`` in ``resources.py``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, tokenize, restore=False, save=False):
|
||||
"""Specify tokenized resource and init base dataset.
|
||||
|
||||
Args:
|
||||
opt (Config or dict): config for dataset or the whole system.
|
||||
tokenize (str): how to tokenize dataset.
|
||||
restore (bool): whether to restore saved dataset which has been processed. Defaults to False.
|
||||
save (bool): whether to save dataset after processing. Defaults to False.
|
||||
|
||||
"""
|
||||
resource = resources[tokenize]
|
||||
self.special_token_idx = resource['special_token_idx']
|
||||
self.unk_token_idx = self.special_token_idx['unk']
|
||||
self.pad_topic_idx = self.special_token_idx['pad_topic']
|
||||
dpath = os.path.join(DATASET_PATH, 'tgredial', tokenize)
|
||||
self.replace_token = opt.get('replace_token', None)
|
||||
self.replace_token_idx = opt.get('replace_token_idx', None)
|
||||
super().__init__(opt, dpath, resource, restore, save)
|
||||
if self.replace_token:
|
||||
if self.replace_token_idx:
|
||||
self.side_data["embedding"][self.replace_token_idx] = self.side_data['embedding'][0]
|
||||
else:
|
||||
self.side_data["embedding"] = np.insert(self.side_data["embedding"],len(self.side_data["embedding"]),self.side_data['embedding'][0],axis=0)
|
||||
|
||||
|
||||
def _load_data(self):
|
||||
train_data, valid_data, test_data = self._load_raw_data()
|
||||
self._load_vocab()
|
||||
self._load_other_data()
|
||||
|
||||
vocab = {
|
||||
'tok2ind': self.tok2ind,
|
||||
'ind2tok': self.ind2tok,
|
||||
'topic2ind': self.topic2ind,
|
||||
'ind2topic': self.ind2topic,
|
||||
'entity2id': self.entity2id,
|
||||
'id2entity': self.id2entity,
|
||||
'word2id': self.word2id,
|
||||
'vocab_size': len(self.tok2ind),
|
||||
'n_topic': len(self.topic2ind) + 1,
|
||||
'n_entity': self.n_entity,
|
||||
'n_word': self.n_word,
|
||||
}
|
||||
vocab.update(self.special_token_idx)
|
||||
|
||||
return train_data, valid_data, test_data, vocab
|
||||
|
||||
def _load_raw_data(self):
|
||||
# load train/valid/test data
|
||||
with open(os.path.join(self.dpath, 'train_data.json'), 'r', encoding='utf-8') as f:
|
||||
train_data = json.load(f)
|
||||
logger.debug(f"[Load train data from {os.path.join(self.dpath, 'train_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'valid_data.json'), 'r', encoding='utf-8') as f:
|
||||
valid_data = json.load(f)
|
||||
logger.debug(f"[Load valid data from {os.path.join(self.dpath, 'valid_data.json')}]")
|
||||
with open(os.path.join(self.dpath, 'test_data.json'), 'r', encoding='utf-8') as f:
|
||||
test_data = json.load(f)
|
||||
logger.debug(f"[Load test data from {os.path.join(self.dpath, 'test_data.json')}]")
|
||||
|
||||
return train_data, valid_data, test_data
|
||||
|
||||
def _load_vocab(self):
|
||||
self.tok2ind = json.load(open(os.path.join(self.dpath, 'token2id.json'), 'r', encoding='utf-8'))
|
||||
self.ind2tok = {idx: word for word, idx in self.tok2ind.items()}
|
||||
# add special tokens
|
||||
if self.replace_token:
|
||||
if self.replace_token not in self.tok2ind:
|
||||
if self.replace_token_idx:
|
||||
self.ind2tok[self.replace_token_idx] = self.replace_token
|
||||
self.tok2ind[self.replace_token] = self.replace_token_idx
|
||||
self.special_token_idx[self.replace_token] = self.replace_token_idx
|
||||
else:
|
||||
self.ind2tok[len(self.tok2ind)] = self.replace_token
|
||||
self.tok2ind[self.replace_token] = len(self.tok2ind)
|
||||
self.special_token_idx[self.replace_token] = len(self.tok2ind)-1
|
||||
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'token2id.json')}]")
|
||||
logger.debug(f"[The size of token2index dictionary is {len(self.tok2ind)}]")
|
||||
logger.debug(f"[The size of index2token dictionary is {len(self.ind2tok)}]")
|
||||
|
||||
self.topic2ind = json.load(open(os.path.join(self.dpath, 'topic2id.json'), 'r', encoding='utf-8'))
|
||||
self.ind2topic = {idx: word for word, idx in self.topic2ind.items()}
|
||||
|
||||
logger.debug(f"[Load vocab from {os.path.join(self.dpath, 'topic2id.json')}]")
|
||||
logger.debug(f"[The size of token2index dictionary is {len(self.topic2ind)}]")
|
||||
logger.debug(f"[The size of index2token dictionary is {len(self.ind2topic)}]")
|
||||
|
||||
def _load_other_data(self):
|
||||
# cn-dbpedia
|
||||
self.entity2id = json.load(
|
||||
open(os.path.join(self.dpath, 'entity2id.json'), encoding='utf-8')) # {entity: entity_id}
|
||||
self.id2entity = {idx: entity for entity, idx in self.entity2id.items()}
|
||||
self.n_entity = max(self.entity2id.values()) + 1
|
||||
# {head_entity_id: [(relation_id, tail_entity_id)]}
|
||||
self.entity_kg = open(os.path.join(self.dpath, 'cn-dbpedia.txt'), encoding='utf-8')
|
||||
logger.debug(
|
||||
f"[Load entity dictionary and KG from {os.path.join(self.dpath, 'entity2id.json')} and {os.path.join(self.dpath, 'cn-dbpedia.txt')}]")
|
||||
|
||||
# hownet
|
||||
# {concept: concept_id}
|
||||
self.word2id = json.load(open(os.path.join(self.dpath, 'word2id.json'), 'r', encoding='utf-8'))
|
||||
self.n_word = max(self.word2id.values()) + 1
|
||||
# {relation\t concept \t concept}
|
||||
self.word_kg = open(os.path.join(self.dpath, 'hownet.txt'), encoding='utf-8')
|
||||
logger.debug(
|
||||
f"[Load word dictionary and KG from {os.path.join(self.dpath, 'word2id.json')} and {os.path.join(self.dpath, 'hownet.txt')}]")
|
||||
|
||||
# user interaction history dictionary
|
||||
self.conv2history = json.load(open(os.path.join(self.dpath, 'user2history.json'), 'r', encoding='utf-8'))
|
||||
logger.debug(f"[Load user interaction history from {os.path.join(self.dpath, 'user2history.json')}]")
|
||||
|
||||
# user profile
|
||||
self.user2profile = json.load(open(os.path.join(self.dpath, 'user2profile.json'), 'r', encoding='utf-8'))
|
||||
logger.debug(f"[Load user profile from {os.path.join(self.dpath, 'user2profile.json')}")
|
||||
|
||||
|
||||
def _data_preprocess(self, train_data, valid_data, test_data):
|
||||
processed_train_data = self._raw_data_process(train_data)
|
||||
logger.debug("[Finish train data process]")
|
||||
processed_valid_data = self._raw_data_process(valid_data)
|
||||
logger.debug("[Finish valid data process]")
|
||||
processed_test_data = self._raw_data_process(test_data)
|
||||
logger.debug("[Finish test data process]")
|
||||
processed_side_data = self._side_data_process()
|
||||
logger.debug("[Finish side data process]")
|
||||
return processed_train_data, processed_valid_data, processed_test_data, processed_side_data
|
||||
|
||||
def _raw_data_process(self, raw_data):
|
||||
augmented_convs = [self._convert_to_id(conversation) for conversation in tqdm(raw_data)]
|
||||
augmented_conv_dicts = []
|
||||
for conv in tqdm(augmented_convs):
|
||||
augmented_conv_dicts.extend(self._augment_and_add(conv))
|
||||
return augmented_conv_dicts
|
||||
|
||||
def _convert_to_id(self, conversation):
|
||||
augmented_convs = []
|
||||
last_role = None
|
||||
for utt in conversation['messages']:
|
||||
assert utt['role'] != last_role
|
||||
# change movies into slots
|
||||
if self.replace_token:
|
||||
if len(utt['movie']) != 0:
|
||||
while '《' in utt['text']:  # collapse each 《 ... 》 title span into the slot token
|
||||
begin = utt['text'].index("《")
|
||||
end = utt['text'].index("》")
|
||||
utt['text'] = utt['text'][:begin] + [self.replace_token] + utt['text'][end+1:]
|
||||
text_token_ids = [self.tok2ind.get(word, self.unk_token_idx) for word in utt["text"]]
|
||||
movie_ids = [self.entity2id[movie] for movie in utt['movie'] if movie in self.entity2id]
|
||||
entity_ids = [self.entity2id[entity] for entity in utt['entity'] if entity in self.entity2id]
|
||||
word_ids = [self.word2id[word] for word in utt['word'] if word in self.word2id]
|
||||
policy = []
|
||||
for action, kw in zip(utt['target'][1::2], utt['target'][2::2]):
|
||||
if kw is None or action == '推荐电影':
|
||||
continue
|
||||
if isinstance(kw, str):
|
||||
kw = [kw]
|
||||
kw = [self.topic2ind.get(k, self.pad_topic_idx) for k in kw]
|
||||
policy.append([action, kw])
|
||||
final_kws = [self.topic2ind[kw] if kw is not None else self.pad_topic_idx for kw in utt['final'][1]]
|
||||
final = [utt['final'][0], final_kws]
|
||||
conv_utt_id = str(conversation['conv_id']) + '/' + str(utt['local_id'])
|
||||
interaction_history = self.conv2history.get(conv_utt_id, [])
|
||||
user_profile = self.user2profile[conversation['user_id']]
|
||||
user_profile = [[self.tok2ind.get(token, self.unk_token_idx) for token in sent] for sent in user_profile]
|
||||
|
||||
augmented_convs.append({
|
||||
"role": utt["role"],
|
||||
"text": text_token_ids,
|
||||
"entity": entity_ids,
|
||||
"movie": movie_ids,
|
||||
"word": text_token_ids,
|
||||
'policy': policy,
|
||||
'final': final,
|
||||
'interaction_history': interaction_history,
|
||||
'user_profile': user_profile
|
||||
})
|
||||
last_role = utt["role"]
|
||||
|
||||
return augmented_convs
|
||||
|
||||
def _augment_and_add(self, raw_conv_dict):
|
||||
augmented_conv_dicts = []
|
||||
context_tokens, context_entities, context_words, context_policy, context_items = [], [], [], [], []
|
||||
entity_set, word_set = set(), set()
|
||||
for i, conv in enumerate(raw_conv_dict):
|
||||
text_tokens, entities, movies, words, policies = conv["text"], conv["entity"], conv["movie"], conv["word"], \
|
||||
conv['policy']
|
||||
if self.replace_token is not None:
|
||||
if text_tokens.count(30000) != len(movies):
|
||||
continue # the number of slots doesn't equal to the number of movies
|
||||
|
||||
if len(context_tokens) > 0:
|
||||
conv_dict = {
|
||||
'conv_id': i,
|
||||
'role': conv['role'],
|
||||
'user_profile': conv['user_profile'],
|
||||
"tokens": copy(context_tokens),
|
||||
"response": text_tokens,
|
||||
"entity": copy(context_entities),
|
||||
"word": copy(context_words),
|
||||
'interaction_history': conv['interaction_history'],
|
||||
'item': copy(context_items),
|
||||
"items": movies,
|
||||
'context_policy': copy(context_policy),
|
||||
'target': policies,
|
||||
'final': conv['final'],
|
||||
}
|
||||
augmented_conv_dicts.append(conv_dict)
|
||||
|
||||
context_tokens.append(text_tokens)
|
||||
context_policy.append(policies)
|
||||
context_items += movies
|
||||
for entity in entities + movies:
|
||||
if entity not in entity_set:
|
||||
entity_set.add(entity)
|
||||
context_entities.append(entity)
|
||||
for word in words:
|
||||
if word not in word_set:
|
||||
word_set.add(word)
|
||||
context_words.append(word)
|
||||
|
||||
return augmented_conv_dicts
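    # Illustrative sketch (added for clarity, not part of the original file): how
    # _augment_and_add turns one conversation into incremental samples. t*, e*, m*
    # are hypothetical token/entity/item ids.
    #
    #   utt 1: tokens [t1], entities [e1], movies []      -> no sample, context started
    #   utt 2: tokens [t2], entities [e2], movies [m1]
    #       -> sample 1: tokens=[[t1]], entity=[e1], item=[], items=[m1]
    #   utt 3: tokens [t3], entities [],   movies [m2]
    #       -> sample 2: tokens=[[t1], [t2]], entity=[e1, e2, m1], item=[m1], items=[m2]
    #
    # i.e. every utterance after the first becomes a prediction target, with all
    # earlier turns accumulated as context.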
|
||||
|
||||
def _side_data_process(self):
|
||||
processed_entity_kg = self._entity_kg_process()
|
||||
logger.debug("[Finish entity KG process]")
|
||||
processed_word_kg = self._word_kg_process()
|
||||
logger.debug("[Finish word KG process]")
|
||||
movie_entity_ids = json.load(open(os.path.join(self.dpath, 'movie_ids.json'), 'r', encoding='utf-8'))
|
||||
logger.debug('[Load movie entity ids]')
|
||||
|
||||
side_data = {
|
||||
"entity_kg": processed_entity_kg,
|
||||
"word_kg": processed_word_kg,
|
||||
"item_entity_ids": movie_entity_ids,
|
||||
}
|
||||
return side_data
|
||||
|
||||
def _entity_kg_process(self):
|
||||
edge_list = [] # [(entity, entity, relation)]
|
||||
for line in self.entity_kg:
|
||||
triple = line.strip().split('\t')
|
||||
e0 = self.entity2id[triple[0]]
|
||||
e1 = self.entity2id[triple[2]]
|
||||
r = triple[1]
|
||||
edge_list.append((e0, e1, r))
|
||||
edge_list.append((e1, e0, r))
|
||||
edge_list.append((e0, e0, 'SELF_LOOP'))
|
||||
if e1 != e0:
|
||||
edge_list.append((e1, e1, 'SELF_LOOP'))
|
||||
|
||||
relation_cnt, relation2id, edges, entities = defaultdict(int), dict(), set(), set()
|
||||
for h, t, r in edge_list:
|
||||
relation_cnt[r] += 1
|
||||
for h, t, r in edge_list:
|
||||
if r not in relation2id:
|
||||
relation2id[r] = len(relation2id)
|
||||
edges.add((h, t, relation2id[r]))
|
||||
entities.add(self.id2entity[h])
|
||||
entities.add(self.id2entity[t])
|
||||
|
||||
return {
|
||||
'edge': list(edges),
|
||||
'n_relation': len(relation2id),
|
||||
'entity': list(entities)
|
||||
}
|
||||
|
||||
def _word_kg_process(self):
|
||||
edges = set() # {(entity, entity)}
|
||||
entities = set()
|
||||
for line in self.word_kg:
|
||||
triple = line.strip().split('\t')
|
||||
entities.add(triple[0])
|
||||
entities.add(triple[2])
|
||||
e0 = self.word2id[triple[0]]
|
||||
e1 = self.word2id[triple[2]]
|
||||
edges.add((e0, e1))
|
||||
edges.add((e1, e0))
|
||||
# edge_set = [[co[0] for co in list(edges)], [co[1] for co in list(edges)]]
|
||||
return {
|
||||
'edge': list(edges),
|
||||
'entity': list(entities)
|
||||
}
|
275
HyCoRec/crslab/download.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/7
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/7
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
|
||||
import datetime
|
||||
import requests
|
||||
import tqdm
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class DownloadableFile:
|
||||
"""
|
||||
A class used to abstract any file that has to be downloaded online.
|
||||
|
||||
Any task that needs to download a file needs to have a list RESOURCES
|
||||
that have objects of this class as elements.
|
||||
|
||||
This class provides the following functionality:
|
||||
|
||||
- Download a file from a URL
|
||||
- Untar the file if zipped
|
||||
- Checksum for the downloaded file
|
||||
|
||||
An object of this class needs to be created with:
|
||||
|
||||
- url <string> : URL or Google Drive id to download from
|
||||
- file_name <string> : File name that the file should be named
|
||||
- hashcode <string> : SHA256 hashcode of the downloaded file
|
||||
- zipped <boolean> : False if the file is not compressed
|
||||
- from_google <boolean> : True if the file is from Google Drive
|
||||
"""
|
||||
|
||||
def __init__(self, url, file_name, hashcode, zipped=True, from_google=False):
|
||||
self.url = url
|
||||
self.file_name = file_name
|
||||
self.hashcode = hashcode
|
||||
self.zipped = zipped
|
||||
self.from_google = from_google
|
||||
|
||||
def checksum(self, dpath):
|
||||
"""
|
||||
Checksum on a given file.
|
||||
|
||||
:param dpath: path to the downloaded file.
|
||||
"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(os.path.join(dpath, self.file_name), "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(65536), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
if sha256_hash.hexdigest() != self.hashcode:
|
||||
# remove_dir(dpath)
|
||||
raise AssertionError(
|
||||
f"[ Checksum for {self.file_name} from \n{self.url}\n"
|
||||
"does not match the expected checksum. Please try again. ]"
|
||||
)
|
||||
else:
|
||||
logger.debug("Checksum Successful")
|
||||
pass
|
||||
|
||||
def download_file(self, dpath):
|
||||
if self.from_google:
|
||||
download_from_google_drive(self.url, os.path.join(dpath, self.file_name))
|
||||
else:
|
||||
download(self.url, dpath, self.file_name)
|
||||
|
||||
self.checksum(dpath)
|
||||
|
||||
if self.zipped:
|
||||
untar(dpath, self.file_name)
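# Illustrative usage sketch (added for clarity, not part of the original file).
# The URL, file name and hash below are placeholders, not a real resource.
#
# >>> resource = DownloadableFile(
# ...     url='https://example.com/redial.zip',        # hypothetical URL
# ...     file_name='redial.zip',
# ...     hashcode='0' * 64,                            # hypothetical SHA256
# ...     zipped=True,
# ... )
# >>> resource.download_file('data/dataset/redial')     # download, checksum, untar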
|
||||
|
||||
|
||||
def download(url, path, fname, redownload=False, num_retries=5):
|
||||
"""
|
||||
Download file using `requests`.
|
||||
If ``redownload`` is set to false, then will not download tar file again if it is
|
||||
present (default ``False``).
|
||||
"""
|
||||
outfile = os.path.join(path, fname)
|
||||
download = not os.path.exists(outfile) or redownload
|
||||
logger.info(f"Downloading {url} to {outfile}")
|
||||
retry = num_retries
|
||||
exp_backoff = [2 ** r for r in reversed(range(retry))]
|
||||
|
||||
pbar = tqdm.tqdm(unit='B', unit_scale=True, desc='Downloading {}'.format(fname))
|
||||
|
||||
while download and retry > 0:
|
||||
response = None
|
||||
try:
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.60',
|
||||
}
|
||||
response = requests.get(url, stream=True, headers=headers)
|
||||
|
||||
# negative reply could be 'none' or just missing
|
||||
CHUNK_SIZE = 32768
|
||||
total_size = int(response.headers.get('Content-Length', -1))
|
||||
# server returns remaining size if resuming, so adjust total
|
||||
pbar.total = total_size
|
||||
done = 0
|
||||
|
||||
with open(outfile, 'wb') as f:
|
||||
for chunk in response.iter_content(CHUNK_SIZE):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
if total_size > 0:
|
||||
done += len(chunk)
|
||||
if total_size < done:
|
||||
# don't freak out if content-length was too small
|
||||
total_size = done
|
||||
pbar.total = total_size
|
||||
pbar.update(len(chunk))
|
||||
break
|
||||
except (
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.ReadTimeout,
|
||||
):
|
||||
retry -= 1
|
||||
pbar.clear()
|
||||
if retry > 0:
|
||||
pl = 'y' if retry == 1 else 'ies'
|
||||
logger.debug(
|
||||
f'Connection error, retrying. ({retry} retr{pl} left)'
|
||||
)
|
||||
time.sleep(exp_backoff[retry])
|
||||
else:
|
||||
logger.error('Retried too many times, stopped retrying.')
|
||||
finally:
|
||||
if response:
|
||||
response.close()
|
||||
if retry <= 0:
|
||||
raise RuntimeError('Connection broken too many times. Stopped retrying.')
|
||||
|
||||
if download and retry > 0:
|
||||
pbar.update(done - pbar.n)
|
||||
if done < total_size:
|
||||
raise RuntimeError(
|
||||
f'Received less data than specified in Content-Length header for '
|
||||
f'{url}. There may be a download problem.'
|
||||
)
|
||||
|
||||
pbar.close()
|
||||
|
||||
|
||||
def _get_confirm_token(response):
|
||||
for key, value in response.cookies.items():
|
||||
if key.startswith('download_warning'):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def download_from_google_drive(gd_id, destination):
|
||||
"""
|
||||
Use the requests package to download a file from Google Drive.
|
||||
"""
|
||||
URL = 'https://docs.google.com/uc?export=download'
|
||||
|
||||
with requests.Session() as session:
|
||||
response = session.get(URL, params={'id': gd_id}, stream=True)
|
||||
token = _get_confirm_token(response)
|
||||
|
||||
if token:
|
||||
response.close()
|
||||
params = {'id': gd_id, 'confirm': token}
|
||||
response = session.get(URL, params=params, stream=True)
|
||||
|
||||
CHUNK_SIZE = 32768
|
||||
with open(destination, 'wb') as f:
|
||||
for chunk in response.iter_content(CHUNK_SIZE):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
response.close()
|
||||
|
||||
|
||||
def move(path1, path2):
|
||||
"""
|
||||
Rename the given file.
|
||||
"""
|
||||
shutil.move(path1, path2)
|
||||
|
||||
|
||||
def untar(path, fname, deleteTar=True):
|
||||
"""
|
||||
Unpack the given archive file to the same directory.
|
||||
|
||||
:param str path:
|
||||
The folder containing the archive. Will contain the contents.
|
||||
|
||||
:param str fname:
|
||||
The filename of the archive file.
|
||||
|
||||
:param bool deleteTar:
|
||||
If true, the archive will be deleted after extraction.
|
||||
"""
|
||||
logger.debug(f'unpacking {fname}')
|
||||
fullpath = os.path.join(path, fname)
|
||||
shutil.unpack_archive(fullpath, path)
|
||||
if deleteTar:
|
||||
os.remove(fullpath)
|
||||
|
||||
|
||||
def make_dir(path):
|
||||
"""
|
||||
Make the directory and any nonexistent parent directories (`mkdir -p`).
|
||||
"""
|
||||
# the current working directory is a fine path
|
||||
if path != '':
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
def remove_dir(path):
|
||||
"""
|
||||
Remove the given directory, if it exists.
|
||||
"""
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
|
||||
|
||||
def check_build(path, version_string=None):
|
||||
"""
|
||||
Check if '.built' flag has been set for that task.
|
||||
|
||||
If a version_string is provided, this has to match, or the version is regarded as
|
||||
not built.
|
||||
"""
|
||||
if version_string:
|
||||
fname = os.path.join(path, '.built')
|
||||
if not os.path.isfile(fname):
|
||||
return False
|
||||
else:
|
||||
with open(fname, 'r') as read:
|
||||
text = read.read().split('\n')
|
||||
return len(text) > 1 and text[1] == version_string
|
||||
else:
|
||||
return os.path.isfile(os.path.join(path, '.built'))
|
||||
|
||||
|
||||
def mark_done(path, version_string=None):
|
||||
"""
|
||||
Mark this path as prebuilt.
|
||||
|
||||
Marks the path as done by adding a '.built' file with the current timestamp
|
||||
plus a version description string if specified.
|
||||
|
||||
:param str path:
|
||||
The file path to mark as built.
|
||||
|
||||
:param str version_string:
|
||||
The version of this dataset.
|
||||
"""
|
||||
with open(os.path.join(path, '.built'), 'w') as write:
|
||||
write.write(str(datetime.datetime.today()))
|
||||
if version_string:
|
||||
write.write('\n' + version_string)
|
||||
|
||||
|
||||
def build(dpath, dfile, version=None):
|
||||
if not check_build(dpath, version):
|
||||
logger.info('[Building data: ' + dpath + ']')
|
||||
if check_build(dpath):
|
||||
remove_dir(dpath)
|
||||
make_dir(dpath)
|
||||
# Download the data.
|
||||
downloadable_file = dfile
|
||||
downloadable_file.download_file(dpath)
|
||||
mark_done(dpath, version)
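# Illustrative sketch of the '.built' workflow (added for clarity, not part of the
# original file); the path and version strings are hypothetical.
#
# >>> make_dir('data/dataset/redial')
# >>> mark_done('data/dataset/redial', version_string='0.2')
# >>> check_build('data/dataset/redial', version_string='0.2')
# True
# >>> check_build('data/dataset/redial', version_string='0.3')   # version mismatch
# False
#
# build() only downloads and unpacks when check_build() fails for the requested version.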
|
28
HyCoRec/crslab/evaluator/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/22
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .standard import StandardEvaluator
|
||||
from ..data import dataset_language_map
|
||||
|
||||
Evaluator_register_table = {
|
||||
'standard': StandardEvaluator
|
||||
}
|
||||
|
||||
|
||||
def get_evaluator(evaluator_name, dataset, file_path):
|
||||
if evaluator_name in Evaluator_register_table:
|
||||
language = dataset_language_map[dataset]
|
||||
evaluator = Evaluator_register_table[evaluator_name](language, file_path)
|
||||
logger.info(f'[Build evaluator {evaluator_name}]')
|
||||
return evaluator
|
||||
else:
|
||||
raise NotImplementedError(f'Model [{evaluator_name}] has not been implemented')
|
31
HyCoRec/crslab/evaluator/base.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# @Time : 2020/11/30
|
||||
# @Author : Xiaolei Wang
|
||||
# @Email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/30
|
||||
# @Author : Xiaolei Wang
|
||||
# @Email : wxl1999@foxmail.com
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseEvaluator(ABC):
|
||||
"""Base class for evaluator"""
|
||||
|
||||
def rec_evaluate(self, preds, label):
|
||||
pass
|
||||
|
||||
def gen_evaluate(self, preds, label):
|
||||
pass
|
||||
|
||||
def policy_evaluate(self, preds, label):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def report(self, epoch, mode):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_metrics(self):
|
||||
pass
|
30
HyCoRec/crslab/evaluator/embeddings.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/18
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/18
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from crslab.download import DownloadableFile
|
||||
|
||||
resources = {
|
||||
'zh': {
|
||||
'version': '0.2',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EVyPGnSEWZlGsLn0tpCa7BABjY7u3Ii6o_6aqYzDmw0xNw?download=1',
|
||||
'cc.zh.300.zip',
|
||||
'effd9806809a1db106b5166b817aaafaaf3f005846f730d4c49f88c7a28a0ac3'
|
||||
)
|
||||
},
|
||||
'en': {
|
||||
'version': '0.2',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ee3JyLp8wblAoQfFY7balSYB8g2wRebRek8QLOmYs8jcKw?download=1',
|
||||
'cc.en.300.zip',
|
||||
'96a06a77da70325997eaa52bfd9acb1359a7c3754cb1c1aed2fc27c04936d53e'
|
||||
)
|
||||
}
|
||||
}
|
4
HyCoRec/crslab/evaluator/metrics/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import Metric, Metrics, aggregate_unnamed_reports, AverageMetric
|
||||
from .gen import BleuMetric, ExactMatchMetric, F1Metric, DistMetric, EmbeddingAverage, VectorExtrema, \
|
||||
GreedyMatch
|
||||
from .rec import RECMetric, NDCGMetric, MRRMetric, CovMetric
|
232
HyCoRec/crslab/evaluator/metrics/base.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/24, 2020/12/2
|
||||
# @Author : Kun Zhou, Xiaolei Wang
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
|
||||
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from typing import Any, Union, List, Optional, Dict
|
||||
|
||||
TScalar = Union[int, float, torch.Tensor]
|
||||
TVector = Union[List[TScalar], torch.Tensor]
|
||||
|
||||
|
||||
@functools.total_ordering
|
||||
class Metric(ABC):
|
||||
"""
|
||||
Base class for storing metrics.
|
||||
|
||||
Subclasses should define .value(). Examples are provided for each subclass.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def value(self):
|
||||
"""
|
||||
Return the value of the metric as a float.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __add__(self, other: Any) -> 'Metric':
|
||||
raise NotImplementedError
|
||||
|
||||
def __iadd__(self, other):
|
||||
return self.__radd__(other)
|
||||
|
||||
def __radd__(self, other: Any):
|
||||
if other is None:
|
||||
return self
|
||||
return self.__add__(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{self.value():.4g}'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}({self.value():.4g})'
|
||||
|
||||
def __float__(self) -> float:
|
||||
return float(self.value())
|
||||
|
||||
def __int__(self) -> int:
|
||||
return int(self.value())
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Metric):
|
||||
return self.value() == other.value()
|
||||
else:
|
||||
return self.value() == other
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if isinstance(other, Metric):
|
||||
return self.value() < other.value()
|
||||
else:
|
||||
return self.value() < other
|
||||
|
||||
def __sub__(self, other: Any) -> float:
|
||||
"""
|
||||
Used heavily for assertAlmostEqual.
|
||||
"""
|
||||
if not isinstance(other, float):
|
||||
raise TypeError('Metrics.__sub__ is intentionally limited to floats.')
|
||||
return self.value() - other
|
||||
|
||||
def __rsub__(self, other: Any) -> float:
|
||||
"""
|
||||
Used heavily for assertAlmostEqual.
|
||||
|
||||
NOTE: This is not necessary in python 3.7+.
|
||||
"""
|
||||
if not isinstance(other, float):
|
||||
raise TypeError('Metrics.__rsub__ is intentionally limited to floats.')
|
||||
return other - self.value()
|
||||
|
||||
@classmethod
|
||||
def as_number(cls, obj: TScalar) -> Union[int, float]:
|
||||
if isinstance(obj, torch.Tensor):
|
||||
obj_as_number: Union[int, float] = obj.item()
|
||||
else:
|
||||
obj_as_number = obj # type: ignore
|
||||
assert isinstance(obj_as_number, int) or isinstance(obj_as_number, float)
|
||||
return obj_as_number
|
||||
|
||||
@classmethod
|
||||
def as_float(cls, obj: TScalar) -> float:
|
||||
return float(cls.as_number(obj))
|
||||
|
||||
@classmethod
|
||||
def as_int(cls, obj: TScalar) -> int:
|
||||
return int(cls.as_number(obj))
|
||||
|
||||
@classmethod
|
||||
def many(cls, *objs: List[TVector]) -> List['Metric']:
|
||||
"""
|
||||
Construct many of a Metric from the base parts.
|
||||
|
||||
        Useful if you separately compute numerators and denominators, etc.
|
||||
"""
|
||||
lengths = [len(o) for o in objs]
|
||||
if len(set(lengths)) != 1:
|
||||
raise IndexError(f'Uneven {cls.__name__} constructions: {lengths}')
|
||||
return [cls(*items) for items in zip(*objs)]
|
||||
|
||||
|
||||
class SumMetric(Metric):
|
||||
"""
|
||||
Class that keeps a running sum of some metric.
|
||||
|
||||
Examples of SumMetric include things like "exs", the number of examples seen since
|
||||
the last report, which depends exactly on a teacher.
|
||||
"""
|
||||
|
||||
__slots__ = ('_sum',)
|
||||
|
||||
def __init__(self, sum_: TScalar = 0):
|
||||
if isinstance(sum_, torch.Tensor):
|
||||
self._sum = sum_.item()
|
||||
else:
|
||||
assert isinstance(sum_, (int, float, list))
|
||||
self._sum = sum_
|
||||
|
||||
def __add__(self, other: Optional['SumMetric']) -> 'SumMetric':
|
||||
# NOTE: hinting can be cleaned up with "from __future__ import annotations" when
|
||||
# we drop Python 3.6
|
||||
if other is None:
|
||||
return self
|
||||
full_sum = self._sum + other._sum
|
||||
# always keep the same return type
|
||||
return type(self)(sum_=full_sum)
|
||||
|
||||
def value(self) -> float:
|
||||
return self._sum
|
||||
|
||||
|
||||
class AverageMetric(Metric):
|
||||
"""
|
||||
Class that keeps a running average of some metric.
|
||||
|
||||
Examples of AverageMetrics include hits@1, F1, accuracy, etc. These metrics all have
|
||||
per-example values that can be directly mapped back to a teacher.
|
||||
"""
|
||||
|
||||
__slots__ = ('_numer', '_denom')
|
||||
|
||||
def __init__(self, numer: TScalar, denom: TScalar = 1):
|
||||
self._numer = self.as_number(numer)
|
||||
self._denom = self.as_number(denom)
|
||||
|
||||
def __add__(self, other: Optional['AverageMetric']) -> 'AverageMetric':
|
||||
# NOTE: hinting can be cleaned up with "from __future__ import annotations" when
|
||||
# we drop Python 3.6
|
||||
if other is None:
|
||||
return self
|
||||
full_numer: TScalar = self._numer + other._numer
|
||||
full_denom: TScalar = self._denom + other._denom
|
||||
# always keep the same return type
|
||||
return type(self)(numer=full_numer, denom=full_denom)
|
||||
|
||||
def value(self) -> float:
|
||||
if self._numer == 0 and self._denom == 0:
|
||||
# don't nan out if we haven't counted anything
|
||||
return 0.0
|
||||
if self._denom == 0:
|
||||
return float('nan')
|
||||
return self._numer / self._denom
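# Worked example (added for clarity, not part of the original file): adding two
# AverageMetrics pools their numerators and denominators, so per-example values
# aggregate into a dataset-level average.
#
# >>> (AverageMetric(1) + AverageMetric(0) + AverageMetric(1)).value()
# 0.6666666666666666
# >>> (AverageMetric(2, 4) + AverageMetric(1, 4)).value()   # 3 hits out of 8
# 0.375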
|
||||
|
||||
|
||||
def aggregate_unnamed_reports(reports: List[Dict[str, Metric]]) -> Dict[str, Metric]:
|
||||
"""
|
||||
    Combines metrics without regard for tracking provenance.
|
||||
"""
|
||||
m: Dict[str, Metric] = {}
|
||||
for task_report in reports:
|
||||
for each_metric, value in task_report.items():
|
||||
m[each_metric] = m.get(each_metric) + value
|
||||
return m
|
||||
|
||||
|
||||
class Metrics(object):
|
||||
"""
|
||||
Metrics aggregator.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._data = {}
|
||||
|
||||
def __str__(self):
|
||||
return str(self._data)
|
||||
|
||||
def __repr__(self):
|
||||
return f'Metrics({repr(self._data)})'
|
||||
|
||||
def get(self, key: str):
|
||||
if key in self._data.keys():
|
||||
return self._data[key].value()
|
||||
else:
|
||||
            raise KeyError(key)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.get(item)
|
||||
|
||||
def add(self, key: str, value: Optional[Metric]) -> None:
|
||||
"""
|
||||
Record an accumulation to a metric.
|
||||
"""
|
||||
self._data[key] = self._data.get(key) + value
|
||||
|
||||
def report(self):
|
||||
"""
|
||||
Report the metrics over all data seen so far.
|
||||
"""
|
||||
return {k: (v if "cov" not in k else SumMetric(len(set(v.value())))) for k, v in self._data.items()}
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Clear all the metrics.
|
||||
"""
|
||||
self._data.clear()
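# Illustrative usage sketch (added for clarity, not part of the original file);
# the metric name is hypothetical.
#
# >>> metrics = Metrics()
# >>> metrics.add('recall@1', AverageMetric(1))
# >>> metrics.add('recall@1', AverageMetric(0))
# >>> metrics['recall@1']
# 0.5
#
# report() returns the accumulated Metric objects; keys containing "cov" are
# collapsed into the number of distinct recommended items instead.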
|
158
HyCoRec/crslab/evaluator/metrics/gen.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# @Time : 2020/11/30
|
||||
# @Author : Xiaolei Wang
|
||||
# @Email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/12/18
|
||||
# @Author : Xiaolei Wang
|
||||
# @Email : wxl1999@foxmail.com
|
||||
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from nltk import ngrams
|
||||
from nltk.translate.bleu_score import sentence_bleu
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from typing import List, Optional
|
||||
|
||||
from crslab.evaluator.metrics.base import AverageMetric, SumMetric
|
||||
|
||||
re_art = re.compile(r'\b(a|an|the)\b')
|
||||
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
|
||||
re_space = re.compile(r'\s+')
|
||||
|
||||
|
||||
class PPLMetric(AverageMetric):
|
||||
def value(self):
|
||||
return math.exp(super().value())
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
"""
|
||||
Lower text and remove punctuation, articles and extra whitespace.
|
||||
"""
|
||||
|
||||
s = s.lower()
|
||||
s = re_punc.sub(' ', s)
|
||||
s = re_art.sub(' ', s)
|
||||
s = re_space.sub(' ', s)
|
||||
# s = ' '.join(s.split())
|
||||
return s
|
||||
|
||||
|
||||
class ExactMatchMetric(AverageMetric):
|
||||
@staticmethod
|
||||
def compute(guess: str, answers: List[str]) -> 'ExactMatchMetric':
|
||||
if guess is None or answers is None:
|
||||
return None
|
||||
for a in answers:
|
||||
if guess == a:
|
||||
return ExactMatchMetric(1)
|
||||
return ExactMatchMetric(0)
|
||||
|
||||
|
||||
class F1Metric(AverageMetric):
|
||||
"""
|
||||
Helper class which computes token-level F1.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _prec_recall_f1_score(pred_items, gold_items):
|
||||
"""
|
||||
Compute precision, recall and f1 given a set of gold and prediction items.
|
||||
|
||||
:param pred_items: iterable of predicted values
|
||||
:param gold_items: iterable of gold values
|
||||
|
||||
        :return: the token-level F1 score (0 when there is no overlap)
|
||||
"""
|
||||
common = Counter(gold_items) & Counter(pred_items)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(pred_items)
|
||||
recall = 1.0 * num_same / len(gold_items)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
@staticmethod
|
||||
def compute(guess: str, answers: List[str]) -> 'F1Metric':
|
||||
if guess is None or answers is None:
|
||||
return AverageMetric(0, 0)
|
||||
g_tokens = guess.split()
|
||||
scores = [
|
||||
F1Metric._prec_recall_f1_score(g_tokens, a.split())
|
||||
for a in answers
|
||||
]
|
||||
return F1Metric(max(scores), 1)
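# Worked example (added for clarity, not part of the original file):
#
# >>> F1Metric.compute('the movie was great', ['that movie is great']).value()
# 0.5
#
# The two strings share 2 of their 4 tokens, so precision = recall = f1 = 0.5.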
|
||||
|
||||
|
||||
class BleuMetric(AverageMetric):
|
||||
@staticmethod
|
||||
def compute(guess: str, answers: List[str], k: int) -> Optional['BleuMetric']:
|
||||
"""
|
||||
Compute approximate BLEU score between guess and a set of answers.
|
||||
"""
|
||||
|
||||
weights = [0] * 4
|
||||
weights[k - 1] = 1
|
||||
score = sentence_bleu(
|
||||
[a.split(" ") for a in answers],
|
||||
guess.split(" "),
|
||||
weights=weights,
|
||||
)
|
||||
return BleuMetric(score)
|
||||
|
||||
|
||||
class DistMetric(SumMetric):
|
||||
@staticmethod
|
||||
def compute(sent: str, k: int) -> 'DistMetric':
|
||||
token_set = set()
|
||||
for token in ngrams(sent.split(), k):
|
||||
token_set.add(token)
|
||||
return DistMetric(len(token_set))
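# Worked example (added for clarity, not part of the original file): dist-2 counts
# the distinct bigrams in a single sentence.
#
# >>> DistMetric.compute('the movie is great', 2).value()
# 3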
|
||||
|
||||
|
||||
class EmbeddingAverage(AverageMetric):
|
||||
@staticmethod
|
||||
def _avg_embedding(embedding):
|
||||
return np.sum(embedding, axis=0) / (np.linalg.norm(np.sum(embedding, axis=0)) + 1e-12)
|
||||
|
||||
@staticmethod
|
||||
def compute(hyp_embedding, ref_embeddings) -> 'EmbeddingAverage':
|
||||
hyp_avg_emb = EmbeddingAverage._avg_embedding(hyp_embedding).reshape(1, -1)
|
||||
ref_avg_embs = [EmbeddingAverage._avg_embedding(emb) for emb in ref_embeddings]
|
||||
ref_avg_embs = np.array(ref_avg_embs)
|
||||
return EmbeddingAverage(float(cosine_similarity(hyp_avg_emb, ref_avg_embs).max()))
|
||||
|
||||
|
||||
class VectorExtrema(AverageMetric):
|
||||
@staticmethod
|
||||
def _extreme_embedding(embedding):
|
||||
max_emb = np.max(embedding, axis=0)
|
||||
min_emb = np.min(embedding, axis=0)
|
||||
extreme_emb = np.fromiter(
|
||||
map(lambda x, y: x if ((x > y or x < -y) and y > 0) or ((x < y or x > -y) and y < 0) else y, max_emb,
|
||||
min_emb), dtype=float)
|
||||
return extreme_emb
|
||||
|
||||
@staticmethod
|
||||
def compute(hyp_embedding, ref_embeddings) -> 'VectorExtrema':
|
||||
hyp_ext_emb = VectorExtrema._extreme_embedding(hyp_embedding).reshape(1, -1)
|
||||
ref_ext_embs = [VectorExtrema._extreme_embedding(emb) for emb in ref_embeddings]
|
||||
ref_ext_embs = np.asarray(ref_ext_embs)
|
||||
return VectorExtrema(float(cosine_similarity(hyp_ext_emb, ref_ext_embs).max()))
|
||||
|
||||
|
||||
class GreedyMatch(AverageMetric):
|
||||
@staticmethod
|
||||
def compute(hyp_embedding, ref_embeddings) -> 'GreedyMatch':
|
||||
hyp_emb = np.asarray(hyp_embedding)
|
||||
ref_embs = (np.asarray(ref_embedding) for ref_embedding in ref_embeddings)
|
||||
score_max = 0
|
||||
for ref_emb in ref_embs:
|
||||
sim_mat = cosine_similarity(hyp_emb, ref_emb)
|
||||
score_max = max(score_max, (sim_mat.max(axis=0).mean() + sim_mat.max(axis=1).mean()) / 2)
|
||||
return GreedyMatch(score_max)
|
41
HyCoRec/crslab/evaluator/metrics/rec.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# @Time : 2020/11/30
|
||||
# @Author : Xiaolei Wang
|
||||
# @Email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/12/2
|
||||
# @Author : Xiaolei Wang
|
||||
# @Email : wxl1999@foxmail.com
|
||||
import math
|
||||
|
||||
from crslab.evaluator.metrics.base import AverageMetric, SumMetric
|
||||
|
||||
|
||||
class RECMetric(AverageMetric):
|
||||
@staticmethod
|
||||
def compute(ranks, label, k) -> 'RECMetric':
|
||||
return RECMetric(int(label in ranks[:k]))
|
||||
|
||||
|
||||
class NDCGMetric(AverageMetric):
|
||||
@staticmethod
|
||||
def compute(ranks, label, k) -> 'NDCGMetric':
|
||||
if label in ranks[:k]:
|
||||
label_rank = ranks.index(label)
|
||||
return NDCGMetric(1 / math.log2(label_rank + 2))
|
||||
return NDCGMetric(0)
|
||||
|
||||
|
||||
class MRRMetric(AverageMetric):
|
||||
@staticmethod
|
||||
def compute(ranks, label, k) -> 'MRRMetric':
|
||||
if label in ranks[:k]:
|
||||
label_rank = ranks.index(label)
|
||||
return MRRMetric(1 / (label_rank + 1))
|
||||
return MRRMetric(0)
|
||||
|
||||
|
||||
class CovMetric(SumMetric):
|
||||
@staticmethod
|
||||
def compute(ranks, label, k) -> 'CovMetric':
|
||||
return CovMetric(ranks[:k])
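# Worked example (added for clarity, not part of the original file), with
# hypothetical item ids:
#
# >>> ranks, label = [3, 7, 42, 5], 42
# >>> RECMetric.compute(ranks, label, 3).value()    # label appears in the top 3
# 1.0
# >>> NDCGMetric.compute(ranks, label, 3).value()   # 1 / log2(2 + 2)
# 0.5
# >>> MRRMetric.compute(ranks, label, 3).value()    # 1 / (2 + 1)
# 0.3333333333333333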
|
86
HyCoRec/crslab/evaluator/standard.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# @Time : 2020/11/30
|
||||
# @Author : Xiaolei Wang
|
||||
# @Email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/12/18
|
||||
# @Author : Xiaolei Wang
|
||||
# @Email : wxl1999@foxmail.com
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
from collections import defaultdict
|
||||
from time import perf_counter
|
||||
from loguru import logger
|
||||
from nltk import ngrams
|
||||
|
||||
from crslab.evaluator.base import BaseEvaluator
|
||||
from crslab.evaluator.utils import nice_report
|
||||
from .metrics import *
|
||||
|
||||
|
||||
class StandardEvaluator(BaseEvaluator):
|
||||
"""The evaluator for all kind of model(recommender, conversation, policy)
|
||||
|
||||
Args:
|
||||
rec_metrics: the metrics to evaluate recommender model, including hit@K, ndcg@K and mrr@K
|
||||
dist_set: the set to record dist n-gram
|
||||
dist_cnt: the count of dist n-gram evaluation
|
||||
gen_metrics: the metrics to evaluate conversational model, including bleu, dist, embedding metrics, f1
|
||||
optim_metrics: the metrics to optimize in training
|
||||
"""
|
||||
|
||||
def __init__(self, language, file_path=None):
|
||||
super(StandardEvaluator, self).__init__()
|
||||
self.file_path = file_path
|
||||
self.result_data = []
|
||||
# rec
|
||||
self.rec_metrics = Metrics()
|
||||
# gen
|
||||
self.dist_set = defaultdict(set)
|
||||
self.dist_cnt = 0
|
||||
self.gen_metrics = Metrics()
|
||||
# optim
|
||||
self.optim_metrics = Metrics()
|
||||
|
||||
def rec_evaluate(self, ranks, label):
|
||||
for k in [1, 10, 50]:
|
||||
if len(ranks) >= k:
|
||||
self.rec_metrics.add(f"recall@{k}", RECMetric.compute(ranks, label, k))
|
||||
self.rec_metrics.add(f"ndcg@{k}", NDCGMetric.compute(ranks, label, k))
|
||||
self.rec_metrics.add(f"mrr@{k}", MRRMetric.compute(ranks, label, k))
|
||||
|
||||
def gen_evaluate(self, hyp, refs, seq=None):
|
||||
if hyp:
|
||||
self.gen_metrics.add("f1", F1Metric.compute(hyp, refs))
|
||||
|
||||
for k in range(1, 5):
|
||||
self.gen_metrics.add(f"bleu@{k}", BleuMetric.compute(hyp, refs, k))
|
||||
for token in ngrams(seq, k):
|
||||
self.dist_set[f"dist@{k}"].add(token)
|
||||
self.dist_cnt += 1
|
||||
|
||||
def report(self, epoch=-1, mode='test'):
|
||||
for k, v in self.dist_set.items():
|
||||
self.gen_metrics.add(k, AverageMetric(len(v) / self.dist_cnt))
|
||||
reports = [self.rec_metrics.report(), self.gen_metrics.report(), self.optim_metrics.report()]
|
||||
all_reports = aggregate_unnamed_reports(reports)
|
||||
self.result_data.append({
|
||||
'epoch': epoch,
|
||||
'mode': mode,
|
||||
'report': {k:all_reports[k].value() for k in all_reports}
|
||||
})
|
||||
if self.file_path:
|
||||
json.dump(self.result_data, open(self.file_path, "w", encoding="utf-8"), indent=4, ensure_ascii=False)
|
||||
logger.info('\n' + nice_report(all_reports))
|
||||
|
||||
def reset_metrics(self):
|
||||
# rec
|
||||
self.rec_metrics.clear()
|
||||
# conv
|
||||
self.gen_metrics.clear()
|
||||
self.dist_cnt = 0
|
||||
self.dist_set.clear()
|
||||
# optim
|
||||
self.optim_metrics.clear()
|
160
HyCoRec/crslab/evaluator/utils.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/12/17
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/12/17
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
|
||||
import math
|
||||
import torch
|
||||
from typing import Union, Tuple
|
||||
|
||||
from .metrics import Metric
|
||||
|
||||
|
||||
def _line_width():
|
||||
try:
|
||||
# if we're in an interactive ipython notebook, hardcode a longer width
|
||||
__IPYTHON__ # type: ignore
|
||||
return 128
|
||||
except NameError:
|
||||
return shutil.get_terminal_size((88, 24)).columns
|
||||
|
||||
|
||||
def float_formatter(f: Union[float, int]) -> str:
|
||||
"""
|
||||
Format a float as a pretty string.
|
||||
"""
|
||||
if f != f:
|
||||
# instead of returning nan, return "" so it shows blank in table
|
||||
return ""
|
||||
if isinstance(f, int):
|
||||
# don't do any rounding of integers, leave them alone
|
||||
return str(f)
|
||||
if f >= 1000:
|
||||
# numbers > 1000 just round to the nearest integer
|
||||
s = f'{f:.0f}'
|
||||
else:
|
||||
# otherwise show 4 significant figures, regardless of decimal spot
|
||||
s = f'{f:.4g}'
|
||||
# replace leading 0's with blanks for easier reading
|
||||
# example: -0.32 to -.32
|
||||
s = s.replace('-0.', '-.')
|
||||
if s.startswith('0.'):
|
||||
s = s[1:]
|
||||
# Add the trailing 0's to always show 4 digits
|
||||
# example: .32 to .3200
|
||||
if s[0] == '.' and len(s) < 5:
|
||||
s += '0' * (5 - len(s))
|
||||
return s
|
||||
|
||||
|
||||
def round_sigfigs(x: Union[float, 'torch.Tensor'], sigfigs=4) -> float:
|
||||
"""
|
||||
Round value to specified significant figures.
|
||||
|
||||
:param x: input number
|
||||
:param sigfigs: number of significant figures to return
|
||||
|
||||
:returns: float number rounded to specified sigfigs
|
||||
"""
|
||||
x_: float
|
||||
if isinstance(x, torch.Tensor):
|
||||
x_ = x.item()
|
||||
else:
|
||||
x_ = x # type: ignore
|
||||
|
||||
try:
|
||||
if x_ == 0:
|
||||
return 0
|
||||
return round(x_, -math.floor(math.log10(abs(x_)) - sigfigs + 1))
|
||||
except (ValueError, OverflowError) as ex:
|
||||
if x_ in [float('inf'), float('-inf')] or x_ != x_: # inf or nan
|
||||
return x_
|
||||
else:
|
||||
raise ex
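# Quick examples (added for clarity, not part of the original file):
#
# >>> float_formatter(3.14159)
# '3.142'
# >>> float_formatter(0.05)       # leading zero stripped, padded to 4 digits
# '.0500'
# >>> round_sigfigs(3.14159)
# 3.142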
|
||||
|
||||
|
||||
def _report_sort_key(report_key: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Sorting name for reports.
|
||||
|
||||
Sorts by main metric alphabetically, then by task.
|
||||
"""
|
||||
# if metric is on its own, like "f1", we will return ('', 'f1')
|
||||
# if metric is from multitask, we denote it.
|
||||
# e.g. "convai2/f1" -> ('convai2', 'f1')
|
||||
# we handle multiple cases of / because sometimes teacher IDs have
|
||||
# filenames.
|
||||
fields = report_key.split("/")
|
||||
main_key = fields.pop(-1)
|
||||
sub_key = '/'.join(fields)
|
||||
return (sub_key or 'all', main_key)
|
||||
|
||||
|
||||
def nice_report(report) -> str:
|
||||
"""
|
||||
Render an agent Report as a beautiful string.
|
||||
|
||||
If pandas is installed, we will use it to render as a table. Multitask
|
||||
metrics will be shown per row, e.g.
|
||||
|
||||
.. code-block:
|
||||
f1 ppl
|
||||
all .410 27.0
|
||||
task1 .400 32.0
|
||||
task2 .420 22.0
|
||||
|
||||
If pandas is not available, we will use a dict with like-metrics placed
|
||||
next to each other.
|
||||
"""
|
||||
if not report:
|
||||
return ""
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
use_pandas = True
|
||||
except ImportError:
|
||||
use_pandas = False
|
||||
|
||||
sorted_keys = sorted(report.keys(), key=_report_sort_key)
|
||||
output: OrderedDict[Union[str, Tuple[str, str]], float] = OrderedDict()
|
||||
for k in sorted_keys:
|
||||
v = report[k]
|
||||
if isinstance(v, Metric):
|
||||
v = v.value()
|
||||
if use_pandas:
|
||||
output[_report_sort_key(k)] = v
|
||||
else:
|
||||
output[k] = v
|
||||
|
||||
if use_pandas:
|
||||
line_width = _line_width()
|
||||
|
||||
df = pd.DataFrame([output])
|
||||
df.columns = pd.MultiIndex.from_tuples(df.columns)
|
||||
df = df.stack().transpose().droplevel(0, axis=1)
|
||||
result = " " + df.to_string(
|
||||
na_rep="",
|
||||
line_width=line_width - 3, # -3 for the extra spaces we add
|
||||
float_format=float_formatter,
|
||||
index=df.shape[0] > 1,
|
||||
).replace("\n\n", "\n").replace("\n", "\n ")
|
||||
result = re.sub(r"\s+$", "", result)
|
||||
return result
|
||||
else:
|
||||
return json.dumps(
|
||||
{
|
||||
k: round_sigfigs(v, 4) if isinstance(v, float) else v
|
||||
for k, v in output.items()
|
||||
}
|
||||
)
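# Illustrative sketch (added for clarity, not part of the original file): with
# pandas installed, nice_report renders a one-row aligned table, roughly
#
#     dist@2  recall@1
#       1.73     .0500
#
# for a report like {'dist@2': 1.73, 'recall@1': 0.05}; without pandas it falls
# back to a JSON dump with floats rounded to 4 significant figures.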
|
34
HyCoRec/crslab/model/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/24, 2020/12/24
|
||||
# @Author : Kun Zhou, Xiaolei Wang
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
|
||||
|
||||
# @Time : 2021/10/06
|
||||
# @Author : Zhipeng Zhao
|
||||
# @Email : oran_official@outlook.com
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from .crs import *
|
||||
|
||||
Model_register_table = {
|
||||
'HyCoRec': HyCoRecModel,
|
||||
}
|
||||
|
||||
|
||||
def get_model(config, model_name, device, vocab, side_data=None):
|
||||
if model_name in Model_register_table:
|
||||
model = Model_register_table[model_name](config, device, vocab, side_data)
|
||||
logger.info(f'[Build model {model_name}]')
|
||||
if config.opt["gpu"] == [-1]:
|
||||
return model
|
||||
else:
|
||||
return torch.nn.DataParallel(model, device_ids=config["gpu"])
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Model [{}] has not been implemented'.format(model_name))
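# Illustrative usage sketch (added for clarity, not part of the original file);
# config, vocab and side_data are assumed to come from the CRSLab config loader
# and dataloader.
#
# >>> model = get_model(config, 'HyCoRec', torch.device('cuda:0'), vocab, side_data)
#
# With gpu != [-1] the returned object is a torch.nn.DataParallel wrapper around
# HyCoRecModel; with gpu == [-1] the bare model is returned.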
|
62
HyCoRec/crslab/model/base.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/24, 2020/12/29
|
||||
# @Author : Kun Zhou, Xiaolei Wang
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
from crslab.download import build
|
||||
|
||||
|
||||
class BaseModel(ABC, nn.Module):
|
||||
"""Base class for all models"""
|
||||
|
||||
def __init__(self, opt, device, dpath=None, resource=None):
|
||||
super(BaseModel, self).__init__()
|
||||
self.opt = opt
|
||||
self.device = device
|
||||
|
||||
if resource is not None:
|
||||
self.dpath = dpath
|
||||
dfile = resource['file']
|
||||
build(dpath, dfile, version=resource['version'])
|
||||
|
||||
self.build_model()
|
||||
|
||||
@abstractmethod
|
||||
def build_model(self, *args, **kwargs):
|
||||
"""build model"""
|
||||
pass
|
||||
|
||||
def recommend(self, batch, mode):
|
||||
"""calculate loss and prediction of recommendation for batch under certain mode
|
||||
|
||||
Args:
|
||||
batch (dict or tuple): batch data
|
||||
mode (str, optional): train/valid/test.
|
||||
"""
|
||||
pass
|
||||
|
||||
def converse(self, batch, mode):
|
||||
"""calculate loss and prediction of conversation for batch under certain mode
|
||||
|
||||
Args:
|
||||
batch (dict or tuple): batch data
|
||||
mode (str, optional): train/valid/test.
|
||||
"""
|
||||
pass
|
||||
|
||||
def guide(self, batch, mode):
|
||||
"""calculate loss and prediction of guidance for batch under certain mode
|
||||
|
||||
Args:
|
||||
batch (dict or tuple): batch data
|
||||
mode (str, optional): train/valid/test.
|
||||
"""
|
||||
pass
|
1
HyCoRec/crslab/model/crs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .hycorec import *
|
1
HyCoRec/crslab/model/crs/hycorec/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .hycorec import HyCoRecModel
|
28
HyCoRec/crslab/model/crs/hycorec/attention.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2021/4/1
|
||||
# @Author : Chenzhan Shang
|
||||
# @Email : czshang@outlook.com
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# Multi-head
|
||||
class MHItemAttention(nn.Module):
|
||||
def __init__(self, dim, head_num):
|
||||
super(MHItemAttention, self).__init__()
|
||||
self.MHA = torch.nn.MultiheadAttention(dim, head_num, batch_first=True)
|
||||
|
||||
def forward(self, related_entity, context_entity):
|
||||
"""
|
||||
input:
|
||||
related_entity: (n_r, dim)
|
||||
context_entity: (n_c, dim)
|
||||
output:
|
||||
related_context_entity: (n_c, dim)
|
||||
"""
|
||||
context_entity = torch.unsqueeze(context_entity, 0)
|
||||
related_entity = torch.unsqueeze(related_entity, 0)
|
||||
output, _ = self.MHA(context_entity, related_entity, related_entity)
|
||||
return torch.squeeze(output, 0)
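# Illustrative usage sketch (added for clarity, not part of the original file):
# shapes only, with hypothetical sizes.
#
# >>> attn = MHItemAttention(dim=128, head_num=4)
# >>> related = torch.randn(10, 128)    # n_r = 10 related entities
# >>> context = torch.randn(5, 128)     # n_c = 5 context entities
# >>> attn(related, context).shape      # each context entity attends over the related set
# torch.Size([5, 128])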
|
200
HyCoRec/crslab/model/crs/hycorec/decoder.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from crslab.model.utils.modules.transformer import MultiHeadAttention, TransformerFFN, _normalize, create_position_codes
|
||||
|
||||
|
||||
class TransformerDecoderLayerKG(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_heads,
|
||||
embedding_size,
|
||||
ffn_size,
|
||||
attention_dropout=0.0,
|
||||
relu_dropout=0.0,
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = embedding_size
|
||||
self.ffn_dim = ffn_size
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
self.self_attention = MultiHeadAttention(
|
||||
n_heads, embedding_size, dropout=attention_dropout
|
||||
)
|
||||
self.norm_self_attention = nn.LayerNorm(embedding_size)
|
||||
|
||||
self.session_item_attention = MultiHeadAttention(
|
||||
n_heads, embedding_size, dropout=attention_dropout
|
||||
)
|
||||
self.norm_session_item_attention = nn.LayerNorm(embedding_size)
|
||||
|
||||
self.related_encoder_attention = MultiHeadAttention(
|
||||
n_heads, embedding_size, dropout=attention_dropout
|
||||
)
|
||||
self.norm_related_encoder_attention = nn.LayerNorm(embedding_size)
|
||||
|
||||
self.context_encoder_attention = MultiHeadAttention(
|
||||
n_heads, embedding_size, dropout=attention_dropout
|
||||
)
|
||||
self.norm_context_encoder_attention = nn.LayerNorm(embedding_size)
|
||||
|
||||
self.norm_merge = nn.LayerNorm(embedding_size)
|
||||
|
||||
self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)
|
||||
self.norm3 = nn.LayerNorm(embedding_size)
|
||||
|
||||
def forward(self, x, related_encoder_output, related_encoder_mask, context_encoder_output, context_encoder_mask,
|
||||
session_embedding, session_mask):
|
||||
decoder_mask = self._create_selfattn_mask(x)
|
||||
# first self attn
|
||||
residual = x
|
||||
        # don't peek into the future!
|
||||
x = self.self_attention(query=x, mask=decoder_mask)
|
||||
x = self.dropout(x) # --dropout
|
||||
x = x + residual
|
||||
x = _normalize(x, self.norm_self_attention)
|
||||
|
||||
residual = x
|
||||
x = self.session_item_attention(
|
||||
query=x,
|
||||
key=session_embedding,
|
||||
value=session_embedding,
|
||||
mask=session_mask
|
||||
)
|
||||
x = self.dropout(x)
|
||||
x = residual + x
|
||||
x = _normalize(x, self.norm_session_item_attention)
|
||||
|
||||
residual = x
|
||||
related_x = self.related_encoder_attention(
|
||||
query=x,
|
||||
key=related_encoder_output,
|
||||
value=related_encoder_output,
|
||||
mask=related_encoder_mask
|
||||
)
|
||||
related_x = self.dropout(related_x) # --dropout
|
||||
|
||||
context_x = self.context_encoder_attention(
|
||||
query=x,
|
||||
key=context_encoder_output,
|
||||
value=context_encoder_output,
|
||||
mask=context_encoder_mask
|
||||
)
|
||||
context_x = self.dropout(context_x) # --dropout
|
||||
|
||||
x = related_x * 0.1 + context_x * 0.9 + residual
|
||||
x = _normalize(x, self.norm_merge)
|
||||
|
||||
# finally the ffn
|
||||
residual = x
|
||||
x = self.ffn(x)
|
||||
x = self.dropout(x) # --dropout
|
||||
x = residual + x
|
||||
x = _normalize(x, self.norm3)
|
||||
|
||||
return x
|
||||
|
||||
def _create_selfattn_mask(self, x):
|
||||
# figure out how many timestamps we need
|
||||
bsz = x.size(0)
|
||||
time = x.size(1)
|
||||
# make sure that we don't look into the future
|
||||
mask = torch.tril(x.new(time, time).fill_(1))
|
||||
# broadcast across batch
|
||||
mask = mask.unsqueeze(0).expand(bsz, -1, -1)
|
||||
return mask
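# Illustrative sketch (added for clarity, not part of the original file): for a
# sequence of length 3, _create_selfattn_mask produces a lower-triangular causal
# mask, broadcast over the batch, so position t can only attend to positions <= t:
#
#   [[1, 0, 0],
#    [1, 1, 0],
#    [1, 1, 1]]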
|
||||
|
||||
|
||||
class TransformerDecoderKG(nn.Module):
|
||||
"""
|
||||
Transformer Decoder layer.
|
||||
|
||||
:param int n_heads: the number of multihead attention heads.
|
||||
:param int n_layers: number of transformer layers.
|
||||
:param int embedding_size: the embedding sizes. Must be a multiple of n_heads.
|
||||
:param int ffn_size: the size of the hidden layer in the FFN
|
||||
:param embedding: an embedding matrix for the bottom layer of the transformer.
|
||||
If none, one is created for this encoder.
|
||||
:param float dropout: Dropout used around embeddings and before layer
|
||||
layer normalizations. This is used in Vaswani 2017 and works well on
|
||||
large datasets.
|
||||
:param float attention_dropout: Dropout performed after the multhead attention
|
||||
softmax. This is not used in Vaswani 2017.
|
||||
:param int padding_idx: Reserved padding index in the embeddings matrix.
|
||||
:param bool learn_positional_embeddings: If off, sinusoidal embeddings are
|
||||
used. If on, position embeddings are learned from scratch.
|
||||
:param bool embeddings_scale: Scale embeddings relative to their dimensionality.
|
||||
Found useful in fairseq.
|
||||
:param int n_positions: Size of the position embeddings matrix.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_heads,
|
||||
n_layers,
|
||||
embedding_size,
|
||||
ffn_size,
|
||||
vocabulary_size,
|
||||
embedding=None,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
relu_dropout=0.0,
|
||||
embeddings_scale=True,
|
||||
learn_positional_embeddings=False,
|
||||
padding_idx=None,
|
||||
n_positions=1024,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_size = embedding_size
|
||||
self.ffn_size = ffn_size
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
self.dim = embedding_size
|
||||
self.embeddings_scale = embeddings_scale
|
||||
self.dropout = nn.Dropout(p=dropout) # --dropout
|
||||
|
||||
self.out_dim = embedding_size
|
||||
assert embedding_size % n_heads == 0, \
|
||||
'Transformer embedding size must be a multiple of n_heads'
|
||||
|
||||
self.embeddings = embedding
|
||||
|
||||
# create the positional embeddings
|
||||
self.position_embeddings = nn.Embedding(n_positions, embedding_size)
|
||||
if not learn_positional_embeddings:
|
||||
create_position_codes(
|
||||
n_positions, embedding_size, out=self.position_embeddings.weight
|
||||
)
|
||||
else:
|
||||
nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)
|
||||
|
||||
# build the model
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(self.n_layers):
|
||||
self.layers.append(TransformerDecoderLayerKG(
|
||||
n_heads, embedding_size, ffn_size,
|
||||
attention_dropout=attention_dropout,
|
||||
relu_dropout=relu_dropout,
|
||||
dropout=dropout,
|
||||
))
|
||||
|
||||
def forward(self, input, related_encoder_state, context_encoder_state, session_state, incr_state=None):
|
||||
related_encoder_output, related_encoder_mask = related_encoder_state
|
||||
context_encoder_output, context_encoder_mask = context_encoder_state
|
||||
session_embedding, session_mask = session_state
|
||||
|
||||
seq_len = input.shape[1]
|
||||
positions = input.new_empty(seq_len).long()
|
||||
positions = torch.arange(seq_len, out=positions).unsqueeze(0) # (batch, seq_len)
|
||||
tensor = self.embeddings(input)
|
||||
if self.embeddings_scale:
|
||||
tensor = tensor * np.sqrt(self.dim)
|
||||
tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
|
||||
tensor = self.dropout(tensor) # --dropout
|
||||
|
||||
for layer in self.layers:
|
||||
tensor = layer(tensor, related_encoder_output, related_encoder_mask, context_encoder_output,
|
||||
context_encoder_mask, session_embedding, session_mask)
|
||||
|
||||
return tensor, None
|
729
HyCoRec/crslab/model/crs/hycorec/hycorec.py
Normal file
@@ -0,0 +1,729 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2021/5/26
|
||||
# @Author : Chenzhan Shang
|
||||
# @email : czshang@outlook.com
|
||||
|
||||
r"""
|
||||
PCR
|
||||
====
|
||||
References:
|
||||
Chen, Qibin, et al. `"Towards Knowledge-Based Recommender Dialog System."`_ in EMNLP 2019.
|
||||
|
||||
.. _`"Towards Knowledge-Based Recommender Dialog System."`:
|
||||
https://www.aclweb.org/anthology/D19-1189/
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import os.path
|
||||
import random
|
||||
import pickle
|
||||
from typing import List
|
||||
from time import perf_counter
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from loguru import logger
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from torch_geometric.nn import RGCNConv, HypergraphConv
|
||||
|
||||
from crslab.config import DATA_PATH, DATASET_PATH
|
||||
from crslab.model.base import BaseModel
|
||||
from crslab.model.crs.hycorec.attention import MHItemAttention
|
||||
from crslab.model.utils.functions import edge_to_pyg_format
|
||||
from crslab.model.utils.modules.attention import SelfAttentionBatch, SelfAttentionSeq
|
||||
from crslab.model.utils.modules.transformer import TransformerEncoder
|
||||
from crslab.model.crs.hycorec.decoder import TransformerDecoderKG
|
||||
|
||||
|
||||
class HyCoRecModel(BaseModel):
|
||||
"""
|
||||
|
||||
Attributes:
|
||||
        vocab_size: An integer indicating the vocabulary size.
        pad_token_idx: An integer indicating the id of the padding token.
        start_token_idx: An integer indicating the id of the start token.
        end_token_idx: An integer indicating the id of the end token.
        token_emb_dim: An integer indicating the dimension of the token embedding layer.
        pretrain_embedding: A string indicating the path of the pretrained embedding.
        n_entity: An integer indicating the number of entities.
        n_relation: An integer indicating the number of relations in the KG.
        num_bases: An integer indicating the number of bases.
        kg_emb_dim: An integer indicating the dimension of the kg embedding.
        user_emb_dim: An integer indicating the dimension of the user embedding.
        n_heads: An integer indicating the number of attention heads.
        n_layers: An integer indicating the number of layers.
        ffn_size: An integer indicating the size of the FFN hidden layer.
        dropout: A float indicating the dropout rate.
        attention_dropout: A float indicating the dropout rate of the attention layer.
        relu_dropout: A float indicating the dropout rate of the relu layer.
        learn_positional_embeddings: A boolean indicating if we learn the positional embeddings.
        embeddings_scale: A boolean indicating if we use the embeddings scale.
        reduction: A boolean indicating if we use the reduction.
        n_positions: An integer indicating the number of positions.
        longest_label: An integer indicating the longest length for response generation.
        user_proj_dim: An integer indicating the dimension to project the user embedding to.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, opt, device, vocab, side_data):
|
||||
"""
|
||||
|
||||
Args:
|
||||
opt (dict): A dictionary record the hyper parameters.
|
||||
device (torch.device): A variable indicating which device to place the data and model.
|
||||
vocab (dict): A dictionary record the vocabulary information.
|
||||
side_data (dict): A dictionary record the side data.
|
||||
|
||||
"""
|
||||
self.device = device
|
||||
self.gpu = opt.get("gpu", -1)
|
||||
self.dataset = opt.get("dataset", None)
|
||||
self.llm = opt.get("llm", "chatgpt-4o")
|
||||
assert self.dataset in ['HReDial', 'HTGReDial', 'DuRecDial', 'OpenDialKG', 'ReDial', 'TGReDial']
|
||||
# vocab
|
||||
self.pad_token_idx = vocab['tok2ind']['__pad__']
|
||||
self.start_token_idx = vocab['tok2ind']['__start__']
|
||||
self.end_token_idx = vocab['tok2ind']['__end__']
|
||||
self.vocab_size = vocab['vocab_size']
|
||||
self.token_emb_dim = opt.get('token_emb_dim', 300)
|
||||
self.pretrain_embedding = side_data.get('embedding', None)
|
||||
self.token2id = json.load(open(os.path.join(DATASET_PATH, self.dataset.lower(), opt["tokenize"], "token2id.json"), "r", encoding="utf-8"))
|
||||
self.entity2id = json.load(open(os.path.join(DATASET_PATH, self.dataset.lower(), opt["tokenize"], "entity2id.json"), "r", encoding="utf-8"))
|
||||
# kg
|
||||
self.n_entity = vocab['n_entity']
|
||||
self.entity_kg = side_data['entity_kg']
|
||||
self.n_relation = self.entity_kg['n_relation']
|
||||
self.edge_idx, self.edge_type = edge_to_pyg_format(self.entity_kg['edge'], 'RGCN')
|
||||
self.edge_idx = self.edge_idx.to(device)
|
||||
self.edge_type = self.edge_type.to(device)
|
||||
self.num_bases = opt.get('num_bases', 8)
|
||||
self.kg_emb_dim = opt.get('kg_emb_dim', 300)
|
||||
self.user_emb_dim = self.kg_emb_dim
|
||||
# transformer
|
||||
self.n_heads = opt.get('n_heads', 2)
|
||||
self.n_layers = opt.get('n_layers', 2)
|
||||
self.ffn_size = opt.get('ffn_size', 300)
|
||||
self.dropout = opt.get('dropout', 0.1)
|
||||
self.attention_dropout = opt.get('attention_dropout', 0.0)
|
||||
self.relu_dropout = opt.get('relu_dropout', 0.1)
|
||||
self.embeddings_scale = opt.get('embedding_scale', True)
|
||||
self.learn_positional_embeddings = opt.get('learn_positional_embeddings', False)
|
||||
self.reduction = opt.get('reduction', False)
|
||||
self.n_positions = opt.get('n_positions', 1024)
|
||||
self.longest_label = opt.get('longest_label', 30)
|
||||
self.user_proj_dim = opt.get('user_proj_dim', 512)
|
||||
# pooling
|
||||
self.pooling = opt.get('pooling', None)
|
||||
assert self.pooling == 'Attn' or self.pooling == 'Mean'
|
||||
# MHA
|
||||
self.mha_n_heads = opt.get('mha_n_heads', 4)
|
||||
self.extension_strategy = opt.get('extension_strategy', None)
|
||||
self.pretrain = opt.get('pretrain', False)
|
||||
self.pretrain_data = None
|
||||
self.pretrain_epoch = opt.get('pretrain_epoch', 9999)
|
||||
|
||||
super(HyCoRecModel, self).__init__(opt, device)
|
||||
return
|
||||
|
||||
    # Build the model
|
||||
def build_model(self, *args, **kwargs):
|
||||
if self.pretrain:
|
||||
pretrain_file = os.path.join('pretrain', self.dataset, str(self.pretrain_epoch) + '-epoch.pth')
|
||||
self.pretrain_data = torch.load(pretrain_file, map_location=torch.device('cuda:' + str(self.gpu[0])))
|
||||
logger.info(f"[Load Pretrain Weights from {pretrain_file}]")
|
||||
# self._build_hredial_copy_mask()
|
||||
self._build_adjacent_matrix()
|
||||
# self._build_hllm_data()
|
||||
self._build_embedding()
|
||||
self._build_kg_layer()
|
||||
self._build_recommendation_layer()
|
||||
self._build_conversation_layer()
|
||||
|
||||
    # Build the mask
|
||||
def _build_hredial_copy_mask(self):
|
||||
token_filename = os.path.join(DATASET_PATH, "hredial", "nltk", "token2id.json")
|
||||
token_file = open(token_filename, 'r', encoding="utf-8")
|
||||
token2id = json.load(token_file)
|
||||
id2token = {token2id[token]: token for token in token2id}
|
||||
self.hredial_copy_mask = list()
|
||||
for i in range(len(id2token)):
|
||||
token = id2token[i]
|
||||
if token[0] == '@':
|
||||
self.hredial_copy_mask.append(True)
|
||||
else:
|
||||
self.hredial_copy_mask.append(False)
|
||||
self.hredial_copy_mask = torch.as_tensor(self.hredial_copy_mask).to(self.device)
|
||||
return
|
||||
|
||||
def _build_hllm_data(self):
|
||||
self.hllm_data_table = {
|
||||
"train": pickle.load(open(os.path.join(DATA_PATH, "hllm", self.dataset.lower(), self.llm, "hllm_train_data.pkl"), "rb")),
|
||||
"valid": pickle.load(open(os.path.join(DATA_PATH, "hllm", self.dataset.lower(), self.llm, "hllm_valid_data.pkl"), "rb")),
|
||||
"test": pickle.load(open(os.path.join(DATA_PATH, "hllm", self.dataset.lower(), self.llm, "hllm_test_data.pkl"), "rb")),
|
||||
}
|
||||
return
|
||||
|
||||
# Build the adjacency matrices
|
||||
def _build_adjacent_matrix(self):
|
||||
entity2id = self.entity2id
|
||||
token2id = self.token2id
|
||||
item_edger = pickle.load(open(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "item_edger.pkl"), "rb"))
|
||||
entity_edger = pickle.load(open(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "entity_edger.pkl"), "rb"))
|
||||
word_edger = pickle.load(open(os.path.join(DATA_PATH, "edger", self.dataset.lower(), "word_edger.pkl"), "rb"))
|
||||
|
||||
item_adj = {}
|
||||
for item_a in item_edger:
|
||||
raw_item_list = item_edger[item_a]
|
||||
if item_a not in entity2id:
|
||||
continue
|
||||
item_a = entity2id[item_a]
|
||||
item_list = []
|
||||
for item in raw_item_list:
|
||||
if item not in entity2id:
|
||||
continue
|
||||
item_list.append(entity2id[item])
|
||||
item_adj[item_a] = item_list
|
||||
self.item_adj = item_adj
|
||||
|
||||
entity_adj = {}
|
||||
for entity_a in entity_edger:
|
||||
raw_entity_list = entity_edger[entity_a]
|
||||
if entity_a not in entity2id:
|
||||
continue
|
||||
entity_a = entity2id[entity_a]
|
||||
entity_list = []
|
||||
for entity in raw_entity_list:
|
||||
if entity not in entity2id:
|
||||
continue
|
||||
entity_list.append(entity2id[entity])
|
||||
entity_adj[entity_a] = entity_list
|
||||
self.entity_adj = entity_adj
|
||||
|
||||
word_adj = {}
|
||||
for word_a in word_edger:
|
||||
raw_word_list = word_edger[word_a]
|
||||
if word_a not in token2id:
|
||||
continue
|
||||
word_a = token2id[word_a]
|
||||
word_list = []
|
||||
for word in raw_word_list:
|
||||
if word not in token2id:
|
||||
continue
|
||||
word_list.append(token2id[word])
|
||||
word_adj[word_a] = word_list
|
||||
self.word_adj = word_adj
|
||||
|
||||
logger.info(f"[Adjacent Matrix built.]")
|
||||
return
|
||||
|
||||
# Build the embedding layers
|
||||
def _build_embedding(self):
|
||||
if self.pretrain_embedding is not None:
|
||||
self.token_embedding = nn.Embedding.from_pretrained(
|
||||
torch.as_tensor(self.pretrain_embedding, dtype=torch.float), freeze=False,
|
||||
padding_idx=self.pad_token_idx)
|
||||
else:
|
||||
self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx)
|
||||
nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)
|
||||
nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0)
|
||||
|
||||
self.entity_embedding = nn.Embedding(self.n_entity, self.kg_emb_dim, 0)
|
||||
nn.init.normal_(self.entity_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)
|
||||
nn.init.constant_(self.entity_embedding.weight[0], 0)
|
||||
self.word_embedding = nn.Embedding(self.n_entity, self.kg_emb_dim, 0)
|
||||
nn.init.normal_(self.word_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)
|
||||
nn.init.constant_(self.word_embedding.weight[0], 0)
|
||||
logger.debug('[Build embedding]')
|
||||
return
|
||||
|
||||
# Build the hypergraph encoding layers
|
||||
def _build_kg_layer(self):
|
||||
# graph encoder
|
||||
self.item_encoder = RGCNConv(self.kg_emb_dim, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases)
|
||||
self.entity_encoder = RGCNConv(self.kg_emb_dim, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases)
|
||||
self.word_encoder = RGCNConv(self.kg_emb_dim, self.kg_emb_dim, self.n_relation, num_bases=self.num_bases)
|
||||
if self.pretrain:
|
||||
self.item_encoder.load_state_dict(self.pretrain_data['encoder'])
|
||||
# hypergraph convolution
|
||||
self.hyper_conv_item = HypergraphConv(self.kg_emb_dim, self.kg_emb_dim)
|
||||
self.hyper_conv_entity = HypergraphConv(self.kg_emb_dim, self.kg_emb_dim)
|
||||
self.hyper_conv_word = HypergraphConv(self.kg_emb_dim, self.kg_emb_dim)
|
||||
# attention type
|
||||
self.item_attn = MHItemAttention(self.kg_emb_dim, self.mha_n_heads)
|
||||
# pooling
|
||||
if self.pooling == 'Attn':
|
||||
self.kg_attn = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim)
|
||||
self.kg_attn_his = SelfAttentionBatch(self.kg_emb_dim, self.kg_emb_dim)
|
||||
logger.debug('[Build kg layer]')
|
||||
return
|
||||
|
||||
# Build the recommendation module
|
||||
def _build_recommendation_layer(self):
|
||||
self.rec_bias = nn.Linear(self.kg_emb_dim, self.n_entity)
|
||||
self.rec_loss = nn.CrossEntropyLoss()
|
||||
logger.debug('[Build recommendation layer]')
|
||||
return
|
||||
|
||||
# Build the conversation module
|
||||
def _build_conversation_layer(self):
|
||||
self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long))
|
||||
self.entity_to_token = nn.Linear(self.kg_emb_dim, self.token_emb_dim, bias=True)
|
||||
self.related_encoder = TransformerEncoder(
|
||||
self.n_heads,
|
||||
self.n_layers,
|
||||
self.token_emb_dim,
|
||||
self.ffn_size,
|
||||
self.vocab_size,
|
||||
self.token_embedding,
|
||||
self.dropout,
|
||||
self.attention_dropout,
|
||||
self.relu_dropout,
|
||||
self.pad_token_idx,
|
||||
self.learn_positional_embeddings,
|
||||
self.embeddings_scale,
|
||||
self.reduction,
|
||||
self.n_positions
|
||||
)
|
||||
self.context_encoder = TransformerEncoder(
|
||||
self.n_heads,
|
||||
self.n_layers,
|
||||
self.token_emb_dim,
|
||||
self.ffn_size,
|
||||
self.vocab_size,
|
||||
self.token_embedding,
|
||||
self.dropout,
|
||||
self.attention_dropout,
|
||||
self.relu_dropout,
|
||||
self.pad_token_idx,
|
||||
self.learn_positional_embeddings,
|
||||
self.embeddings_scale,
|
||||
self.reduction,
|
||||
self.n_positions
|
||||
)
|
||||
self.decoder = TransformerDecoderKG(
|
||||
self.n_heads,
|
||||
self.n_layers,
|
||||
self.token_emb_dim,
|
||||
self.ffn_size,
|
||||
self.vocab_size,
|
||||
self.token_embedding,
|
||||
self.dropout,
|
||||
self.attention_dropout,
|
||||
self.relu_dropout,
|
||||
self.embeddings_scale,
|
||||
self.learn_positional_embeddings,
|
||||
self.pad_token_idx,
|
||||
self.n_positions
|
||||
)
|
||||
self.user_proj_1 = nn.Linear(self.user_emb_dim, self.user_proj_dim)
|
||||
self.user_proj_2 = nn.Linear(self.user_proj_dim, self.vocab_size)
|
||||
self.conv_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx)
|
||||
|
||||
self.copy_proj_1 = nn.Linear(2 * self.token_emb_dim, self.token_emb_dim)
|
||||
self.copy_proj_2 = nn.Linear(self.token_emb_dim, self.vocab_size)
|
||||
logger.debug('[Build conversation layer]')
|
||||
return
|
||||
|
||||
# Build the hypergraph (node list and hyperedge index)
|
||||
def _get_hypergraph(self, related, adj):
|
||||
related_items_set = set()
|
||||
for related_items in related:
|
||||
related_items_set.add(related_items)
|
||||
session_related_items = list(related_items_set)
|
||||
|
||||
hypergraph_nodes, hypergraph_edges, hyper_edge_counter = list(), list(), 0
|
||||
for item in session_related_items:
|
||||
hypergraph_nodes.append(item)
|
||||
hypergraph_edges.append(hyper_edge_counter)
|
||||
neighbors = list(adj.get(item, []))
|
||||
hypergraph_nodes += neighbors
|
||||
hypergraph_edges += [hyper_edge_counter] * len(neighbors)
|
||||
hyper_edge_counter += 1
|
||||
hyper_edge_index = torch.tensor([hypergraph_nodes, hypergraph_edges], device=self.device)
|
||||
return list(set(hypergraph_nodes)), hyper_edge_index
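# Illustrative sketch (hypothetical toy values): with related = [3, 7] and
# adj = {3: [5, 9], 7: [5]}, one hyperedge is built per seed item, giving
# hypergraph_nodes = [3, 5, 9, 7, 5] and hypergraph_edges = [0, 0, 0, 1, 1];
# hyper_edge_index stacks them into a (2, num_incidences) tensor and the first
# return value is the deduplicated node list (set order, so unordered).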
|
||||
|
||||
# Aggregate sub-graph embeddings
|
||||
def _get_embedding(self, hypergraph_items, embedding, tot_sub, adj):
|
||||
knowledge_embedding_list = []
|
||||
for item in hypergraph_items:
|
||||
sub_graph = [item] + list(adj.get(item, []))
|
||||
sub_graph = [tot_sub[item] for item in sub_graph]
|
||||
sub_graph_embedding = embedding[sub_graph]
|
||||
sub_graph_embedding = torch.mean(sub_graph_embedding, dim=0)
|
||||
knowledge_embedding_list.append(sub_graph_embedding)
|
||||
res_embedding = torch.zeros(1, self.kg_emb_dim).to(self.device)
|
||||
if len(knowledge_embedding_list) > 0:
|
||||
res_embedding = torch.stack(knowledge_embedding_list, dim=0)
|
||||
return res_embedding
|
||||
|
||||
@staticmethod
|
||||
def flatten(inputs):
|
||||
outputs = set()
|
||||
for li in inputs:
|
||||
for i in li:
|
||||
outputs.add(i)
|
||||
return list(outputs)
|
||||
|
||||
# Fuse the feature vectors with attention
|
||||
def _attention_and_gating(self, session_embedding, knowledge_embedding, conceptnet_embedding, context_embedding):
|
||||
related_embedding = torch.cat((session_embedding, knowledge_embedding, conceptnet_embedding), dim=0)
|
||||
if context_embedding is None:
|
||||
if self.pooling == 'Attn':
|
||||
user_repr = self.kg_attn_his(related_embedding)
|
||||
else:
|
||||
assert self.pooling == 'Mean'
|
||||
user_repr = torch.mean(related_embedding, dim=0)
|
||||
elif self.pooling == 'Attn':
|
||||
attentive_related_embedding = self.item_attn(related_embedding, context_embedding)
|
||||
user_repr = self.kg_attn_his(attentive_related_embedding)
|
||||
user_repr = torch.unsqueeze(user_repr, dim=0)
|
||||
user_repr = torch.cat((context_embedding, user_repr), dim=0)
|
||||
user_repr = self.kg_attn(user_repr)
|
||||
else:
|
||||
assert self.pooling == 'Mean'
|
||||
attentive_related_embedding = self.item_attn(related_embedding, context_embedding)
|
||||
user_repr = torch.mean(attentive_related_embedding, dim=0)
|
||||
user_repr = torch.unsqueeze(user_repr, dim=0)
|
||||
user_repr = torch.cat((context_embedding, user_repr), dim=0)
|
||||
user_repr = torch.mean(user_repr, dim=0)
|
||||
return user_repr
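# Summary of the branches above: with no context entities, the stacked
# item/entity/word rows are pooled directly (self-attention or mean, per
# self.pooling); with context, MHItemAttention first attends the related rows
# over the context embedding, the result is pooled, concatenated with the
# context rows, and pooled once more into a single (kg_emb_dim,) user vector.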
|
||||
|
||||
def _get_hllm_embedding(self, tot_embedding, hllm_hyper_graph, adj, conv):
|
||||
hllm_hyper_edge_A = []
|
||||
hllm_hyper_edge_B = []
|
||||
for idx, hyper_edge in enumerate(hllm_hyper_graph):
|
||||
hllm_hyper_edge_A += [item for item in hyper_edge]
|
||||
hllm_hyper_edge_B += [idx] * len(hyper_edge)
|
||||
|
||||
hllm_items = list(set(hllm_hyper_edge_A))
|
||||
sub_item2id = {item:idx for idx, item in enumerate(hllm_items)}
|
||||
sub_embedding = tot_embedding[hllm_items]
|
||||
|
||||
hllm_hyper_edge = [[sub_item2id[item] for item in hllm_hyper_edge_A], hllm_hyper_edge_B]
|
||||
hllm_hyper_edge = torch.LongTensor(hllm_hyper_edge).to(self.device)
|
||||
|
||||
embedding = conv(sub_embedding, hllm_hyper_edge)
|
||||
|
||||
return embedding
|
||||
|
||||
def encode_user_repr(self, related_items, related_entities, related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding):
|
||||
# COLD START
|
||||
# if len(related_items) == 0 or len(related_words) == 0:
|
||||
# if len(related_entities) == 0:
|
||||
# user_repr = torch.zeros(self.user_emb_dim, device=self.device)
|
||||
# elif self.pooling == 'Attn':
|
||||
# user_repr = tot_entity_embedding[related_entities]
|
||||
# user_repr = self.kg_attn(user_repr)
|
||||
# else:
|
||||
# assert self.pooling == 'Mean'
|
||||
# user_repr = tot_entity_embedding[related_entities]
|
||||
# user_repr = torch.mean(user_repr, dim=0)
|
||||
# return user_repr
|
||||
|
||||
# Hypergraph-convolved representations
|
||||
item_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
|
||||
if len(related_items) > 0:
|
||||
items, item_hyper_edge_index = self._get_hypergraph(related_items, self.item_adj)
|
||||
sub_item_embedding, sub_item_edge_index, item_tot2sub = self._before_hyperconv(tot_item_embedding, items, item_hyper_edge_index, self.item_adj)
|
||||
raw_item_embedding = self.hyper_conv_item(sub_item_embedding, sub_item_edge_index)
|
||||
item_embedding = raw_item_embedding
|
||||
# item_embedding = self._get_embedding(items, raw_item_embedding, item_tot2sub, self.item_adj)
|
||||
|
||||
entity_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
|
||||
if len(related_entities) > 0:
|
||||
entities, entity_hyper_edge_index = self._get_hypergraph(related_entities, self.entity_adj)
|
||||
sub_entity_embedding, sub_entity_edge_index, entity_tot2sub = self._before_hyperconv(tot_entity_embedding, entities, entity_hyper_edge_index, self.entity_adj)
|
||||
raw_entity_embedding = self.hyper_conv_entity(sub_entity_embedding, sub_entity_edge_index)
|
||||
entity_embedding = raw_entity_embedding
|
||||
# entity_embedding = self._get_embedding(entities, raw_entity_embedding, entity_tot2sub, self.entity_adj)
|
||||
|
||||
word_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
|
||||
if len(related_words) > 0:
|
||||
words, word_hyper_edge_index = self._get_hypergraph(related_words, self.word_adj)
|
||||
sub_word_embedding, sub_word_edge_index, word_tot2sub = self._before_hyperconv(tot_word_embedding, words, word_hyper_edge_index, self.word_adj)
|
||||
raw_word_embedding = self.hyper_conv_word(sub_word_embedding, sub_word_edge_index)
|
||||
word_embedding = raw_word_embedding
|
||||
# word_embedding = self._get_embedding(words, raw_word_embedding, word_tot2sub, self.word_adj)
|
||||
|
||||
# Attention-based fusion
|
||||
if len(related_entities) == 0:
|
||||
user_repr = self._attention_and_gating(item_embedding, entity_embedding, word_embedding, None)
|
||||
else:
|
||||
context_embedding = tot_entity_embedding[related_entities]
|
||||
user_repr = self._attention_and_gating(item_embedding, entity_embedding, word_embedding, context_embedding)
|
||||
return user_repr
|
||||
|
||||
def process_hllm(self, hllm_data, id_dict):
|
||||
res_data = []
|
||||
for raw_hyper_graph in hllm_data:
|
||||
if not isinstance(raw_hyper_graph, list):
|
||||
continue
|
||||
temp_hyper_graph = []
|
||||
for meta_data in raw_hyper_graph:
|
||||
if not isinstance(meta_data, int):
|
||||
continue
|
||||
if meta_data not in id_dict:
|
||||
continue
|
||||
temp_hyper_graph.append(id_dict[meta_data])
|
||||
res_data.append(temp_hyper_graph)
|
||||
return res_data
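# process_hllm keeps only well-formed hyperedges: each element of hllm_data must
# be a list of integer ids, and every id is remapped through id_dict (e.g.
# entity2id or token2id); malformed entries and unknown ids are dropped silently.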
|
||||
|
||||
# Encode user representations
|
||||
def encode_user(self, batch_related_items, batch_related_entities, batch_related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding):
|
||||
user_repr_list = []
|
||||
for related_items, related_entities, related_words in zip(batch_related_items, batch_related_entities, batch_related_words):
|
||||
user_repr = self.encode_user_repr(related_items, related_entities, related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding)
|
||||
user_repr_list.append(user_repr)
|
||||
user_embedding = torch.stack(user_repr_list, dim=0)
|
||||
# print("user_embedding.shape", user_embedding.shape) # [6, 128]
|
||||
return user_embedding
|
||||
|
||||
# Recommendation module
|
||||
def recommend(self, batch, mode):
|
||||
# Unpack the batch
|
||||
conv_id = batch['conv_id']
|
||||
related_item = batch['related_item']
|
||||
related_entity = batch['related_entity']
|
||||
related_word = batch['related_word']
|
||||
item = batch['item']
|
||||
item_embedding = self.item_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type)
|
||||
entity_embedding = self.entity_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type)
|
||||
token_embedding = self.word_encoder(self.word_embedding.weight, self.edge_idx, self.edge_type)
|
||||
|
||||
# Encode user representation
|
||||
# start = perf_counter()
|
||||
user_embedding = self.encode_user(
|
||||
related_item,
|
||||
related_entity,
|
||||
related_word,
|
||||
item_embedding,
|
||||
entity_embedding,
|
||||
token_embedding,
|
||||
) # (batch_size, emb_dim)
|
||||
# print(f"{perf_counter() - start:.2f}")
|
||||
|
||||
# Score every entity
|
||||
scores = F.linear(user_embedding, entity_embedding, self.rec_bias.bias) # (batch_size, n_entity)
|
||||
loss = self.rec_loss(scores, item)
|
||||
return loss, scores
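# Note: F.linear(user_embedding, entity_embedding, self.rec_bias.bias) computes
# user_embedding @ entity_embedding.T + bias, i.e. a dot-product score between
# each user vector and every RGCN-encoded entity, so scores has shape
# (batch_size, n_entity) and feeds straight into CrossEntropyLoss.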
|
||||
|
||||
def _starts(self, batch_size):
|
||||
"""Return bsz start tokens."""
|
||||
return self.START.detach().expand(batch_size, 1)
|
||||
|
||||
def freeze_parameters(self):
|
||||
freeze_models = [
|
||||
self.entity_embedding,
|
||||
self.token_embedding,
|
||||
self.item_encoder,
|
||||
self.entity_encoder,
|
||||
self.word_encoder,
|
||||
self.hyper_conv_item,
|
||||
self.hyper_conv_entity,
|
||||
self.hyper_conv_word,
|
||||
self.item_attn,
|
||||
self.rec_bias
|
||||
]
|
||||
if self.pooling == "Attn":
|
||||
freeze_models.append(self.kg_attn)
|
||||
freeze_models.append(self.kg_attn_his)
|
||||
for model in freeze_models:
|
||||
for p in model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
def _before_hyperconv(self, embeddings:torch.FloatTensor, hypergraph_items:List[int], edge_index:torch.LongTensor, adj):
|
||||
sub_items = []
|
||||
edge_index = edge_index.cpu().numpy()
|
||||
for item in hypergraph_items:
|
||||
sub_items += [item] + list(adj.get(item, []))
|
||||
sub_items = list(set(sub_items))
|
||||
tot2sub = {tot:sub for sub, tot in enumerate(sub_items)}
|
||||
sub_embeddings = embeddings[sub_items]
|
||||
edge_index = [[tot2sub[v] for v in edge_index[0]], list(edge_index[1])]
|
||||
sub_edge_index = torch.tensor(edge_index).long()
|
||||
sub_edge_index = sub_edge_index.to(self.device)
|
||||
return sub_embeddings, sub_edge_index, tot2sub
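# Sketch of the remapping (hypothetical values): if hypergraph_items = [3, 7]
# and adj = {3: [5], 7: []}, sub_items might be [3, 5, 7] with
# tot2sub = {3: 0, 5: 1, 7: 2}; the global node ids in row 0 of edge_index are
# rewritten to these local indices so they address rows of sub_embeddings.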
|
||||
|
||||
# Encode the session with hypergraph-convolved data
|
||||
def encode_session(self, batch_related_items, batch_related_entities, batch_related_words, tot_item_embedding, tot_entity_embedding, tot_word_embedding):
|
||||
"""
|
||||
Return: session_repr (batch_size, batch_seq_len, token_emb_dim), mask (batch_size, batch_seq_len)
|
||||
"""
|
||||
session_repr_list = []
|
||||
for session_related_items, session_related_entities, session_related_words in zip(batch_related_items, batch_related_entities, batch_related_words):
|
||||
# COLD START
|
||||
# if len(session_related_items) == 0 or len(session_related_words) == 0:
|
||||
# if len(session_related_entities) == 0:
|
||||
# session_repr_list.append(None)
|
||||
# else:
|
||||
# session_repr = tot_entity_embedding[session_related_entities]
|
||||
# session_repr_list.append(session_repr)
|
||||
# continue
|
||||
|
||||
# Hypergraph-convolved representations
|
||||
item_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
|
||||
if len(session_related_items) > 0:
|
||||
items, item_hyper_edge_index = self._get_hypergraph(session_related_items, self.item_adj)
|
||||
sub_item_embedding, sub_item_edge_index, item_tot2sub = self._before_hyperconv(tot_item_embedding, items, item_hyper_edge_index, self.item_adj)
|
||||
raw_item_embedding = self.hyper_conv_item(sub_item_embedding, sub_item_edge_index)
|
||||
item_embedding = raw_item_embedding
|
||||
# item_embedding = self._get_embedding(items, raw_item_embedding, item_tot2sub, self.item_adj)
|
||||
|
||||
entity_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
|
||||
if len(session_related_entities) > 0:
|
||||
entities, entity_hyper_edge_index = self._get_hypergraph(session_related_entities, self.entity_adj)
|
||||
sub_entity_embedding, sub_entity_edge_index, entity_tot2sub = self._before_hyperconv(tot_entity_embedding, entities, entity_hyper_edge_index, self.entity_adj)
|
||||
raw_entity_embedding = self.hyper_conv_entity(sub_entity_embedding, sub_entity_edge_index)
|
||||
entity_embedding = raw_entity_embedding
|
||||
# entity_embedding = self._get_embedding(entities, raw_entity_embedding, entity_tot2sub, self.entity_adj)
|
||||
|
||||
word_embedding = torch.zeros((1, self.kg_emb_dim), device=self.device)
|
||||
if len(session_related_words) > 0:
|
||||
words, word_hyper_edge_index = self._get_hypergraph(session_related_words, self.word_adj)
|
||||
sub_word_embedding, sub_word_edge_index, word_tot2sub = self._before_hyperconv(tot_word_embedding, words, word_hyper_edge_index, self.word_adj)
|
||||
raw_word_embedding = self.hyper_conv_word(sub_word_embedding, sub_word_edge_index)
|
||||
word_embedding = raw_word_embedding
|
||||
# word_embedding = self._get_embedding(words, raw_word_embedding, word_tot2sub, self.word_adj)
|
||||
|
||||
# Concatenate the representations
|
||||
if len(session_related_entities) == 0:
|
||||
session_repr = torch.cat((item_embedding, entity_embedding, word_embedding), dim=0)
|
||||
session_repr_list.append(session_repr)
|
||||
else:
|
||||
context_embedding = tot_entity_embedding[session_related_entities]
|
||||
session_repr = torch.cat((item_embedding, entity_embedding, context_embedding, word_embedding), dim=0)
|
||||
session_repr_list.append(session_repr)
|
||||
|
||||
batch_seq_len = max([session_repr.size(0) for session_repr in session_repr_list if session_repr is not None])
|
||||
mask_list = []
|
||||
for i in range(len(session_repr_list)):
|
||||
if session_repr_list[i] is None:
|
||||
mask_list.append([False] * batch_seq_len)
|
||||
zero_repr = torch.zeros((batch_seq_len, self.kg_emb_dim), device=self.device, dtype=torch.float)
|
||||
session_repr_list[i] = zero_repr
|
||||
else:
|
||||
mask_list.append([False] * (batch_seq_len - session_repr_list[i].size(0)) + [True] * session_repr_list[i].size(0))
|
||||
zero_repr = torch.zeros((batch_seq_len - session_repr_list[i].size(0), self.kg_emb_dim),
|
||||
device=self.device, dtype=torch.float)
|
||||
session_repr_list[i] = torch.cat((zero_repr, session_repr_list[i]), dim=0)
|
||||
|
||||
session_repr_embedding = torch.stack(session_repr_list, dim=0)
|
||||
session_repr_embedding = self.entity_to_token(session_repr_embedding)
|
||||
# print("session_repr_embedding.shape", session_repr_embedding.shape) # [6, 7, 300]
|
||||
return session_repr_embedding, torch.tensor(mask_list, device=self.device, dtype=torch.bool)
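# Each per-dialog representation is left-padded with zero rows up to the longest
# one in the batch; mask entries are True for real rows and False for padding,
# and entity_to_token projects from kg_emb_dim to token_emb_dim so the decoder
# can attend over the session state.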
|
||||
|
||||
# Response generation (teacher forcing)
|
||||
def decode_forced(self, related_encoder_state, context_encoder_state, session_state, user_embedding, resp):
|
||||
bsz = resp.size(0)
|
||||
seqlen = resp.size(1)
|
||||
inputs = resp.narrow(1, 0, seqlen - 1)
|
||||
inputs = torch.cat([self._starts(bsz), inputs], 1)
|
||||
latent, _ = self.decoder(inputs, related_encoder_state, context_encoder_state, session_state)
|
||||
token_logits = F.linear(latent, self.token_embedding.weight)
|
||||
user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)
|
||||
|
||||
user_latent = self.entity_to_token(user_embedding)
|
||||
user_latent = user_latent.unsqueeze(1).expand(-1, seqlen, -1)
|
||||
copy_latent = torch.cat((user_latent, latent), dim=-1)
|
||||
copy_logits = self.copy_proj_2(torch.relu(self.copy_proj_1(copy_latent)))
|
||||
if self.dataset == 'HReDial':
|
||||
copy_logits = copy_logits * self.hredial_copy_mask.unsqueeze(0).unsqueeze(0) # not for tg-redial
|
||||
sum_logits = token_logits + user_logits + copy_logits
|
||||
_, preds = sum_logits.max(dim=-1)
|
||||
return sum_logits, preds
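# decode_forced sums three logit sources per position: vocabulary logits from
# the decoder states, a per-dialog user bias from the two user_proj layers, and
# copy logits from concatenating the projected user vector with the decoder
# states (optionally masked to item tokens for HReDial).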
|
||||
|
||||
# Response generation - test (greedy decoding)
|
||||
def decode_greedy(self, related_encoder_state, context_encoder_state, session_state, user_embedding):
|
||||
bsz = context_encoder_state[0].shape[0]
|
||||
xs = self._starts(bsz)
|
||||
incr_state = None
|
||||
logits = []
|
||||
for i in range(self.longest_label):
|
||||
scores, incr_state = self.decoder(xs, related_encoder_state, context_encoder_state, session_state, incr_state) # incr_state is always None
|
||||
scores = scores[:, -1:, :]
|
||||
token_logits = F.linear(scores, self.token_embedding.weight)
|
||||
user_logits = self.user_proj_2(torch.relu(self.user_proj_1(user_embedding))).unsqueeze(1)
|
||||
|
||||
user_latent = self.entity_to_token(user_embedding)
|
||||
user_latent = user_latent.unsqueeze(1).expand(-1, 1, -1)
|
||||
copy_latent = torch.cat((user_latent, scores), dim=-1)
|
||||
copy_logits = self.copy_proj_2(torch.relu(self.copy_proj_1(copy_latent)))
|
||||
if self.dataset == 'HReDial':
|
||||
copy_logits = copy_logits * self.hredial_copy_mask.unsqueeze(0).unsqueeze(0) # not for tg-redial
|
||||
sum_logits = token_logits + user_logits + copy_logits
|
||||
probs, preds = sum_logits.max(dim=-1)
|
||||
logits.append(scores)
|
||||
xs = torch.cat([xs, preds], dim=1)
|
||||
# check if everyone has generated an end token
|
||||
all_finished = ((xs == self.end_token_idx).sum(dim=1) > 0).sum().item() == bsz
|
||||
if all_finished:
|
||||
break
|
||||
logits = torch.cat(logits, 1)
|
||||
return logits, xs
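# The greedy loop feeds the running prefix xs back into the decoder at each
# step, sums the vocabulary, user and copy logits, takes the argmax token, and
# stops early once every sequence in the batch has produced end_token_idx or
# longest_label steps have been generated.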
|
||||
|
||||
# Conversation module training
|
||||
def converse(self, batch, mode):
|
||||
# Unpack the batch
|
||||
conv_id = batch['conv_id']
|
||||
related_item = batch['related_item']
|
||||
related_entity = batch['related_entity']
|
||||
related_word = batch['related_word']
|
||||
response = batch['response']
|
||||
|
||||
related_tokens = batch['related_tokens']
|
||||
context_tokens = batch['context_tokens']
|
||||
|
||||
item_embedding = self.item_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type)
|
||||
entity_embedding = self.entity_encoder(self.entity_embedding.weight, self.edge_idx, self.edge_type)
|
||||
token_embedding = self.word_encoder(self.word_embedding.weight, self.edge_idx, self.edge_type)
|
||||
|
||||
# Encode the session
|
||||
session_state = self.encode_session(
|
||||
related_item,
|
||||
related_entity,
|
||||
related_word,
|
||||
item_embedding,
|
||||
entity_embedding,
|
||||
token_embedding,
|
||||
)
|
||||
|
||||
# Encode user representation
|
||||
# start = perf_counter()
|
||||
user_embedding = self.encode_user(
|
||||
related_item,
|
||||
related_entity,
|
||||
related_word,
|
||||
item_embedding,
|
||||
entity_embedding,
|
||||
token_embedding,
|
||||
) # (batch_size, emb_dim)
|
||||
|
||||
# Encode X_c and X_h
|
||||
related_encoder_state = self.related_encoder(related_tokens)
|
||||
context_encoder_state = self.context_encoder(context_tokens)
|
||||
|
||||
# Response generation
|
||||
if mode != 'test':
|
||||
self.longest_label = max(self.longest_label, response.shape[1])
|
||||
logits, preds = self.decode_forced(related_encoder_state, context_encoder_state, session_state, user_embedding, response)
|
||||
logits = logits.view(-1, logits.shape[-1])
|
||||
labels = response.view(-1)
|
||||
return self.conv_loss(logits, labels), preds
|
||||
else:
|
||||
_, preds = self.decode_greedy(related_encoder_state, context_encoder_state, session_state, user_embedding)
|
||||
return preds
|
||||
|
||||
# The recommendation and conversation modules are trained separately
|
||||
def forward(self, batch, mode, stage):
|
||||
if len(self.gpu) >= 2:
|
||||
self.edge_idx = self.edge_idx.cuda(torch.cuda.current_device())
|
||||
self.edge_type = self.edge_type.cuda(torch.cuda.current_device())
|
||||
if stage == "conv":
|
||||
return self.converse(batch, mode)
|
||||
if stage == "rec":
|
||||
# start = perf_counter()
|
||||
res = self.recommend(batch, mode)
|
||||
# print(f"{perf_counter() - start:.2f}")
|
||||
return res
|
64
HyCoRec/crslab/model/pretrained_models.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2021/1/6
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2021/1/7
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from crslab.download import DownloadableFile
|
||||
|
||||
"""Download links of pretrain models.
|
||||
|
||||
Now we provide the following models:
|
||||
|
||||
- `BERT`_: zh, en
|
||||
- `GPT2`_: zh, en
|
||||
|
||||
.. _BERT:
|
||||
https://www.aclweb.org/anthology/N19-1423/
|
||||
.. _GPT2:
|
||||
https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
|
||||
|
||||
"""
|
||||
|
||||
resources = {
|
||||
'bert': {
|
||||
'zh': {
|
||||
'version': '0.1',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EXm6uTgSkO1PgDD3TV9UtzMBfsAlJOun12vwB-hVkPRbXw?download=1',
|
||||
'bert_zh.zip',
|
||||
'e48ff2f3c2409bb766152dc5577cd5600838c9052622fd6172813dce31806ed3'
|
||||
)
|
||||
},
|
||||
'en': {
|
||||
'version': '0.1',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EfcnG_CkYAtKvEFUWvRF8i0BwmtCKnhnjOBwPW0W1tXqMQ?download=1',
|
||||
'bert_en.zip',
|
||||
'61b08202e8ad09088c9af78ab3f8902cd990813f6fa5b8b296d0da9d370006e3'
|
||||
)
|
||||
},
|
||||
},
|
||||
'gpt2': {
|
||||
'zh': {
|
||||
'version': '0.1',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/EdwPgkE_-_BCsVSqo4Ao9D8BKj6H_0wWGGxHxt_kPmoSwA?download=1',
|
||||
'gpt2_zh.zip',
|
||||
'5f366b729e509164bfd55026e6567e22e101bfddcfaac849bae96fc263c7de43'
|
||||
)
|
||||
},
|
||||
'en': {
|
||||
'version': '0.1',
|
||||
'file': DownloadableFile(
|
||||
'https://pkueducn-my.sharepoint.com/:u:/g/personal/franciszhou_pku_edu_cn/Ebe4PS0rYQ9InxmGvJ9JNXgBMI808ibQc93N-dAubtbTgQ?download=1',
|
||||
'gpt2_en.zip',
|
||||
'518c1c8a1868d4433d93688f2bf7f34b6216334395d1800d66308a80f4cac35e'
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
0
HyCoRec/crslab/model/utils/__init__.py
Normal file
37
HyCoRec/crslab/model/utils/functions.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2020/11/26
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2020/11/16
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def edge_to_pyg_format(edge, type='RGCN'):
|
||||
if type == 'RGCN':
|
||||
edge_sets = torch.as_tensor(edge, dtype=torch.long)
|
||||
edge_idx = edge_sets[:, :2].t()
|
||||
edge_type = edge_sets[:, 2]
|
||||
return edge_idx, edge_type
|
||||
elif type == 'GCN':
|
||||
edge_set = [[co[0] for co in edge], [co[1] for co in edge]]
|
||||
return torch.as_tensor(edge_set, dtype=torch.long)
|
||||
else:
|
||||
raise NotImplementedError('type {} has not been implemented'.format(type))
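# Minimal usage sketch (hypothetical triples): each edge is (head, tail, relation),
# so edge_to_pyg_format([[0, 1, 2], [3, 4, 5]], 'RGCN') returns
# edge_idx = tensor([[0, 3], [1, 4]]) (heads in row 0, tails in row 1) and
# edge_type = tensor([2, 5]), the index/relation pair expected by RGCNConv.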
|
||||
|
||||
|
||||
def sort_for_packed_sequence(lengths: torch.Tensor):
|
||||
"""
|
||||
:param lengths: 1D array of lengths
|
||||
:return: sorted_lengths (lengths in descending order), sorted_idx (indices to sort), rev_idx (indices to retrieve original order)
|
||||
|
||||
"""
|
||||
sorted_idx = torch.argsort(lengths, descending=True) # idx to sort by length
|
||||
rev_idx = torch.argsort(sorted_idx) # idx to retrieve original order
|
||||
sorted_lengths = lengths[sorted_idx]
|
||||
|
||||
return sorted_lengths, sorted_idx, rev_idx
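# Example (hypothetical lengths): for lengths = tensor([2, 5, 3]) this returns
# sorted_lengths = tensor([5, 3, 2]), sorted_idx = tensor([1, 2, 0]) and
# rev_idx = tensor([2, 0, 1]), so sorted_lengths[rev_idx] restores the original order.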
|
0
HyCoRec/crslab/model/utils/modules/__init__.py
Normal file
64
HyCoRec/crslab/model/utils/modules/attention.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/24
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SelfAttentionBatch(nn.Module):
|
||||
def __init__(self, dim, da, alpha=0.2, dropout=0.5):
|
||||
super(SelfAttentionBatch, self).__init__()
|
||||
self.dim = dim
|
||||
self.da = da
|
||||
self.alpha = alpha
|
||||
self.dropout = dropout
|
||||
self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
|
||||
self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
|
||||
nn.init.xavier_uniform_(self.a.data, gain=1.414)
|
||||
nn.init.xavier_uniform_(self.b.data, gain=1.414)
|
||||
|
||||
def forward(self, h):
|
||||
# h: (N, dim)
|
||||
e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1)
|
||||
attention = F.softmax(e, dim=0) # (N)
|
||||
return torch.matmul(attention, h) # (dim)
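# Shape walk-through: h is (N, dim); tanh(h @ a) is (N, da); multiplying by b and
# squeezing yields one score per row, softmax over the N rows gives the attention
# weights, and the weighted sum collapses the batch to a single (dim,) vector.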
|
||||
|
||||
|
||||
class SelfAttentionSeq(nn.Module):
|
||||
def __init__(self, dim, da, alpha=0.2, dropout=0.5):
|
||||
super(SelfAttentionSeq, self).__init__()
|
||||
self.dim = dim
|
||||
self.da = da
|
||||
self.alpha = alpha
|
||||
self.dropout = dropout
|
||||
self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
|
||||
self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
|
||||
nn.init.xavier_uniform_(self.a.data, gain=1.414)
|
||||
nn.init.xavier_uniform_(self.b.data, gain=1.414)
|
||||
|
||||
def forward(self, h, mask=None, return_logits=False):
|
||||
"""
|
||||
For padding tokens the corresponding mask entries are True; if a whole row is padding
|
||||
(mask all True), masking is skipped for that row to avoid an all -inf softmax
|
||||
"""
|
||||
# h: (batch, seq_len, dim), mask: (batch, seq_len)
|
||||
e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b) # (batch, seq_len, 1)
|
||||
if mask is not None:
|
||||
full_mask = -1e30 * mask.float()
|
||||
batch_mask = torch.sum((mask == False), -1).bool().float().unsqueeze(-1) # for all padding one, the mask=0
|
||||
mask = full_mask * batch_mask
|
||||
e += mask.unsqueeze(-1)
|
||||
attention = F.softmax(e, dim=1) # (batch, seq_len, 1)
|
||||
# (batch, dim)
|
||||
if return_logits:
|
||||
return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1), attention.squeeze(-1)
|
||||
else:
|
||||
return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1)
|
471
HyCoRec/crslab/model/utils/modules/transformer.py
Normal file
@@ -0,0 +1,471 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/24
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
"""Near infinity, useful as a large penalty for scoring when inf is bad."""
|
||||
NEAR_INF = 1e20
|
||||
NEAR_INF_FP16 = 65504
|
||||
|
||||
|
||||
def neginf(dtype):
|
||||
"""Returns a representable finite number near -inf for a dtype."""
|
||||
if dtype is torch.float16:
|
||||
return -NEAR_INF_FP16
|
||||
else:
|
||||
return -NEAR_INF
|
||||
|
||||
|
||||
def _create_selfattn_mask(x):
|
||||
# figure out how many timestamps we need
|
||||
bsz = x.size(0)
|
||||
time = x.size(1)
|
||||
# make sure that we don't look into the future
|
||||
mask = torch.tril(x.new(time, time).fill_(1))
|
||||
# broadcast across batch
|
||||
mask = mask.unsqueeze(0).expand(bsz, -1, -1)
|
||||
return mask
|
||||
|
||||
|
||||
def create_position_codes(n_pos, dim, out):
|
||||
position_enc = np.array([
|
||||
[pos / np.power(10000, 2 * j / dim) for j in range(dim // 2)]
|
||||
for pos in range(n_pos)
|
||||
])
|
||||
|
||||
out.data[:, 0::2] = torch.as_tensor(np.sin(position_enc))
|
||||
out.data[:, 1::2] = torch.as_tensor(np.cos(position_enc))
|
||||
out.detach_()
|
||||
out.requires_grad = False
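# This fills `out` with the fixed sinusoidal position codes of Vaswani et al. (2017):
# even embedding dimensions get sin(pos / 10000^(2j/dim)) and odd dimensions the
# matching cos term; the tensor is detached so it is never updated by training.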
|
||||
|
||||
|
||||
def _normalize(tensor, norm_layer):
|
||||
"""Broadcast layer norm"""
|
||||
size = tensor.size()
|
||||
return norm_layer(tensor.view(-1, size[-1])).view(size)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, n_heads, dim, dropout=.0):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
self.n_heads = n_heads
|
||||
self.dim = dim
|
||||
|
||||
self.attn_dropout = nn.Dropout(p=dropout) # --attention-dropout
|
||||
self.q_lin = nn.Linear(dim, dim)
|
||||
self.k_lin = nn.Linear(dim, dim)
|
||||
self.v_lin = nn.Linear(dim, dim)
|
||||
# TODO: merge for the initialization step
|
||||
nn.init.xavier_normal_(self.q_lin.weight)
|
||||
nn.init.xavier_normal_(self.k_lin.weight)
|
||||
nn.init.xavier_normal_(self.v_lin.weight)
|
||||
# and set biases to 0
|
||||
self.out_lin = nn.Linear(dim, dim)
|
||||
|
||||
nn.init.xavier_normal_(self.out_lin.weight)
|
||||
|
||||
def forward(self, query, key=None, value=None, mask=None):
|
||||
# Input is [B, query_len, dim]
|
||||
# Mask is [B, key_len] (selfattn) or [B, key_len, key_len] (enc attn)
|
||||
batch_size, query_len, dim = query.size()
|
||||
assert dim == self.dim, \
|
||||
f'Dimensions do not match: {dim} query vs {self.dim} configured'
|
||||
assert mask is not None, 'Mask is None, please specify a mask'
|
||||
n_heads = self.n_heads
|
||||
dim_per_head = dim // n_heads
|
||||
scale = math.sqrt(dim_per_head)
|
||||
|
||||
def prepare_head(tensor):
|
||||
# input is [batch_size, seq_len, n_heads * dim_per_head]
|
||||
# output is [batch_size * n_heads, seq_len, dim_per_head]
|
||||
bsz, seq_len, _ = tensor.size()
|
||||
tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head)
|
||||
tensor = tensor.transpose(1, 2).contiguous().view(
|
||||
batch_size * n_heads,
|
||||
seq_len,
|
||||
dim_per_head
|
||||
)
|
||||
return tensor
|
||||
|
||||
# q, k, v are the transformed values
|
||||
if key is None and value is None:
|
||||
# self attention
|
||||
key = value = query
|
||||
elif value is None:
|
||||
# key and value are the same, but query differs
|
||||
# self attention
|
||||
value = key
|
||||
_, key_len, dim = key.size()
|
||||
|
||||
q = prepare_head(self.q_lin(query))
|
||||
k = prepare_head(self.k_lin(key))
|
||||
v = prepare_head(self.v_lin(value))
|
||||
|
||||
dot_prod = q.div_(scale).bmm(k.transpose(1, 2))
|
||||
# [B * n_heads, query_len, key_len]
|
||||
attn_mask = (
|
||||
(mask == 0)
|
||||
.view(batch_size, 1, -1, key_len)
|
||||
.repeat(1, n_heads, 1, 1)
|
||||
.expand(batch_size, n_heads, query_len, key_len)
|
||||
.view(batch_size * n_heads, query_len, key_len)
|
||||
)
|
||||
assert attn_mask.shape == dot_prod.shape
|
||||
dot_prod.masked_fill_(attn_mask, neginf(dot_prod.dtype))
|
||||
|
||||
attn_weights = F.softmax(dot_prod, dim=-1).type_as(query)
|
||||
attn_weights = self.attn_dropout(attn_weights) # --attention-dropout
|
||||
|
||||
attentioned = attn_weights.bmm(v)
|
||||
attentioned = (
|
||||
attentioned.type_as(query)
|
||||
.view(batch_size, n_heads, query_len, dim_per_head)
|
||||
.transpose(1, 2).contiguous()
|
||||
.view(batch_size, query_len, dim)
|
||||
)
|
||||
|
||||
out = self.out_lin(attentioned)
|
||||
|
||||
return out
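# Overall shape contract: query (B, query_len, dim) and key/value (B, key_len, dim)
# come in, heads are folded into the batch dimension for the scaled dot product,
# positions where the mask is zero are filled with neginf before the softmax, and
# the heads are merged back so the output is again (B, query_len, dim).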
|
||||
|
||||
|
||||
class TransformerFFN(nn.Module):
|
||||
def __init__(self, dim, dim_hidden, relu_dropout=.0):
|
||||
super(TransformerFFN, self).__init__()
|
||||
self.relu_dropout = nn.Dropout(p=relu_dropout)
|
||||
self.lin1 = nn.Linear(dim, dim_hidden)
|
||||
self.lin2 = nn.Linear(dim_hidden, dim)
|
||||
nn.init.xavier_uniform_(self.lin1.weight)
|
||||
nn.init.xavier_uniform_(self.lin2.weight)
|
||||
# TODO: initialize biases to 0
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.lin1(x))
|
||||
x = self.relu_dropout(x) # --relu-dropout
|
||||
x = self.lin2(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_heads,
|
||||
embedding_size,
|
||||
ffn_size,
|
||||
attention_dropout=0.0,
|
||||
relu_dropout=0.0,
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = embedding_size
|
||||
self.ffn_dim = ffn_size
|
||||
self.attention = MultiHeadAttention(
|
||||
n_heads, embedding_size,
|
||||
dropout=attention_dropout, # --attention-dropout
|
||||
)
|
||||
self.norm1 = nn.LayerNorm(embedding_size)
|
||||
self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)
|
||||
self.norm2 = nn.LayerNorm(embedding_size)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, tensor, mask):
|
||||
tensor = tensor + self.dropout(self.attention(tensor, mask=mask))
|
||||
tensor = _normalize(tensor, self.norm1)
|
||||
tensor = tensor + self.dropout(self.ffn(tensor))
|
||||
tensor = _normalize(tensor, self.norm2)
|
||||
tensor *= mask.unsqueeze(-1).type_as(tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder module.
|
||||
|
||||
:param int n_heads: the number of multihead attention heads.
|
||||
:param int n_layers: number of transformer layers.
|
||||
:param int embedding_size: the embedding sizes. Must be a multiple of n_heads.
|
||||
:param int ffn_size: the size of the hidden layer in the FFN
|
||||
:param embedding: an embedding matrix for the bottom layer of the transformer.
|
||||
If none, one is created for this encoder.
|
||||
:param float dropout: Dropout used around embeddings and before layer
|
||||
normalizations. This is used in Vaswani 2017 and works well on
|
||||
large datasets.
|
||||
:param float attention_dropout: Dropout performed after the multihead attention
|
||||
softmax. This is not used in Vaswani 2017.
|
||||
:param float relu_dropout: Dropout used after the ReLU in the FFN. Not used
|
||||
in Vaswani 2017, but used in Tensor2Tensor.
|
||||
:param int padding_idx: Reserved padding index in the embeddings matrix.
|
||||
:param bool learn_positional_embeddings: If off, sinusoidal embeddings are
|
||||
used. If on, position embeddings are learned from scratch.
|
||||
:param bool embeddings_scale: Scale embeddings relative to their dimensionality.
|
||||
Found useful in fairseq.
|
||||
:param bool reduction: If true, returns the mean vector for the entire encoding
|
||||
sequence.
|
||||
:param int n_positions: Size of the position embeddings matrix.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_heads,
|
||||
n_layers,
|
||||
embedding_size,
|
||||
ffn_size,
|
||||
vocabulary_size,
|
||||
embedding=None,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
relu_dropout=0.0,
|
||||
padding_idx=0,
|
||||
learn_positional_embeddings=False,
|
||||
embeddings_scale=False,
|
||||
reduction=True,
|
||||
n_positions=1024
|
||||
):
|
||||
super(TransformerEncoder, self).__init__()
|
||||
|
||||
self.embedding_size = embedding_size
|
||||
self.ffn_size = ffn_size
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
self.dim = embedding_size
|
||||
self.embeddings_scale = embeddings_scale
|
||||
self.reduction = reduction
|
||||
self.padding_idx = padding_idx
|
||||
# this is --dropout, not --relu-dropout or --attention-dropout
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.out_dim = embedding_size
|
||||
assert embedding_size % n_heads == 0, \
|
||||
'Transformer embedding size must be a multiple of n_heads'
|
||||
|
||||
# check input formats:
|
||||
if embedding is not None:
|
||||
assert (
|
||||
embedding_size is None or embedding_size == embedding.weight.shape[1]
|
||||
), "Embedding dim must match the embedding size."
|
||||
|
||||
if embedding is not None:
|
||||
self.embeddings = embedding
|
||||
else:
|
||||
assert False
|
||||
assert padding_idx is not None
|
||||
self.embeddings = nn.Embedding(
|
||||
vocabulary_size, embedding_size, padding_idx=padding_idx
|
||||
)
|
||||
nn.init.normal_(self.embeddings.weight, 0, embedding_size ** -0.5)
|
||||
|
||||
# create the positional embeddings
|
||||
self.position_embeddings = nn.Embedding(n_positions, embedding_size)
|
||||
if not learn_positional_embeddings:
|
||||
create_position_codes(
|
||||
n_positions, embedding_size, out=self.position_embeddings.weight
|
||||
)
|
||||
else:
|
||||
nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)
|
||||
|
||||
# build the model
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(self.n_layers):
|
||||
self.layers.append(TransformerEncoderLayer(
|
||||
n_heads, embedding_size, ffn_size,
|
||||
attention_dropout=attention_dropout,
|
||||
relu_dropout=relu_dropout,
|
||||
dropout=dropout,
|
||||
))
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
input is a LongTensor of token indices with shape [batch, seq_len];
|
||||
the padding mask is derived internally (True inside the sequence, False at
|
||||
padding positions) and is also returned when reduction is off.
|
||||
"""
|
||||
mask = input != self.padding_idx
|
||||
positions = (mask.cumsum(dim=1, dtype=torch.int64) - 1).clamp_(min=0)
|
||||
tensor = self.embeddings(input)
|
||||
if self.embeddings_scale:
|
||||
tensor = tensor * np.sqrt(self.dim)
|
||||
tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
|
||||
# --dropout on the embeddings
|
||||
tensor = self.dropout(tensor)
|
||||
|
||||
tensor *= mask.unsqueeze(-1).type_as(tensor)
|
||||
for i in range(self.n_layers):
|
||||
tensor = self.layers[i](tensor, mask)
|
||||
|
||||
if self.reduction:
|
||||
divisor = mask.type_as(tensor).sum(dim=1).unsqueeze(-1).clamp(min=1e-7)
|
||||
output = tensor.sum(dim=1) / divisor
|
||||
return output
|
||||
else:
|
||||
output = tensor
|
||||
return output, mask
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_heads,
|
||||
embedding_size,
|
||||
ffn_size,
|
||||
attention_dropout=0.0,
|
||||
relu_dropout=0.0,
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = embedding_size
|
||||
self.ffn_dim = ffn_size
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
self.self_attention = MultiHeadAttention(
|
||||
n_heads, embedding_size, dropout=attention_dropout
|
||||
)
|
||||
self.norm1 = nn.LayerNorm(embedding_size)
|
||||
|
||||
self.encoder_attention = MultiHeadAttention(
|
||||
n_heads, embedding_size, dropout=attention_dropout
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(embedding_size)
|
||||
|
||||
self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)
|
||||
self.norm3 = nn.LayerNorm(embedding_size)
|
||||
|
||||
def forward(self, x, encoder_output, encoder_mask):
|
||||
decoder_mask = self._create_selfattn_mask(x)
|
||||
# first self attn
|
||||
residual = x
|
||||
# don't peek into the future!
|
||||
x = self.self_attention(query=x, mask=decoder_mask)
|
||||
x = self.dropout(x) # --dropout
|
||||
x = x + residual
|
||||
x = _normalize(x, self.norm1)
|
||||
|
||||
residual = x
|
||||
x = self.encoder_attention(
|
||||
query=x,
|
||||
key=encoder_output,
|
||||
value=encoder_output,
|
||||
mask=encoder_mask
|
||||
)
|
||||
x = self.dropout(x) # --dropout
|
||||
x = residual + x
|
||||
x = _normalize(x, self.norm2)
|
||||
|
||||
# finally the ffn
|
||||
residual = x
|
||||
x = self.ffn(x)
|
||||
x = self.dropout(x) # --dropout
|
||||
x = residual + x
|
||||
x = _normalize(x, self.norm3)
|
||||
|
||||
return x
|
||||
|
||||
def _create_selfattn_mask(self, x):
|
||||
# figure out how many timestamps we need
|
||||
bsz = x.size(0)
|
||||
time = x.size(1)
|
||||
# make sure that we don't look into the future
|
||||
mask = torch.tril(x.new(time, time).fill_(1))
|
||||
# broadcast across batch
|
||||
mask = mask.unsqueeze(0).expand(bsz, -1, -1)
|
||||
return mask
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
"""
|
||||
Transformer Decoder layer.
|
||||
|
||||
:param int n_heads: the number of multihead attention heads.
|
||||
:param int n_layers: number of transformer layers.
|
||||
:param int embedding_size: the embedding sizes. Must be a multiple of n_heads.
|
||||
:param int ffn_size: the size of the hidden layer in the FFN
|
||||
:param embedding: an embedding matrix for the bottom layer of the transformer.
|
||||
If none, one is created for this encoder.
|
||||
:param float dropout: Dropout used around embeddings and before layer
|
||||
normalizations. This is used in Vaswani 2017 and works well on
|
||||
large datasets.
|
||||
:param float attention_dropout: Dropout performed after the multihead attention
|
||||
softmax. This is not used in Vaswani 2017.
|
||||
:param int padding_idx: Reserved padding index in the embeddings matrix.
|
||||
:param bool learn_positional_embeddings: If off, sinusoidal embeddings are
|
||||
used. If on, position embeddings are learned from scratch.
|
||||
:param bool embeddings_scale: Scale embeddings relative to their dimensionality.
|
||||
Found useful in fairseq.
|
||||
:param int n_positions: Size of the position embeddings matrix.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_heads,
|
||||
n_layers,
|
||||
embedding_size,
|
||||
ffn_size,
|
||||
vocabulary_size,
|
||||
embedding=None,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
relu_dropout=0.0,
|
||||
embeddings_scale=True,
|
||||
learn_positional_embeddings=False,
|
||||
padding_idx=None,
|
||||
n_positions=1024,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_size = embedding_size
|
||||
self.ffn_size = ffn_size
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
self.dim = embedding_size
|
||||
self.embeddings_scale = embeddings_scale
|
||||
self.dropout = nn.Dropout(p=dropout) # --dropout
|
||||
|
||||
self.out_dim = embedding_size
|
||||
assert embedding_size % n_heads == 0, \
|
||||
'Transformer embedding size must be a multiple of n_heads'
|
||||
|
||||
self.embeddings = embedding
|
||||
|
||||
# create the positional embeddings
|
||||
self.position_embeddings = nn.Embedding(n_positions, embedding_size)
|
||||
if not learn_positional_embeddings:
|
||||
create_position_codes(
|
||||
n_positions, embedding_size, out=self.position_embeddings.weight
|
||||
)
|
||||
else:
|
||||
nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)
|
||||
|
||||
# build the model
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(self.n_layers):
|
||||
self.layers.append(TransformerDecoderLayer(
|
||||
n_heads, embedding_size, ffn_size,
|
||||
attention_dropout=attention_dropout,
|
||||
relu_dropout=relu_dropout,
|
||||
dropout=dropout,
|
||||
))
|
||||
|
||||
def forward(self, input, encoder_state, incr_state=None):
|
||||
encoder_output, encoder_mask = encoder_state
|
||||
|
||||
seq_len = input.shape[1]
|
||||
positions = input.new_empty(seq_len).long()
|
||||
positions = torch.arange(seq_len, out=positions).unsqueeze(0) # (batch, seq_len)
|
||||
tensor = self.embeddings(input)
|
||||
if self.embeddings_scale:
|
||||
tensor = tensor * np.sqrt(self.dim)
|
||||
tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
|
||||
tensor = self.dropout(tensor) # --dropout
|
||||
|
||||
for layer in self.layers:
|
||||
tensor = layer(tensor, encoder_output, encoder_mask)
|
||||
|
||||
return tensor, None
|
1
HyCoRec/crslab/quick_start/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .quick_start import run_crslab
|
82
HyCoRec/crslab/quick_start/quick_start.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# @Time : 2021/1/8
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
# UPDATE
|
||||
# @Time : 2021/1/9
|
||||
# @Author : Xiaolei Wang
|
||||
# @email : wxl1999@foxmail.com
|
||||
|
||||
from crslab.config import Config
|
||||
from crslab.data import get_dataset, get_dataloader
|
||||
from crslab.system import get_system
|
||||
|
||||
|
||||
def run_crslab(config, save_data=False, restore_data=False, save_system=False, restore_system=False,
|
||||
interact=False, debug=False):
|
||||
"""A fast running api, which includes the complete process of training and testing models on specified datasets.
|
||||
|
||||
Args:
|
||||
config (Config or str): an instance of ``Config`` or path to the config file,
|
||||
which should be in ``yaml`` format. You can use default config provided in the `Github repo`_,
|
||||
or write your own.
|
||||
save_data (bool): whether to save data. Defaults to False.
|
||||
restore_data (bool): whether to restore data. Defaults to False.
|
||||
save_system (bool): whether to save system. Defaults to False.
|
||||
restore_system (bool): whether to restore system. Defaults to False.
|
||||
interact (bool): whether to interact with the system. Defaults to False.
|
||||
debug (bool): whether to debug the system. Defaults to False.
|
||||
|
||||
.. _Github repo:
|
||||
https://github.com/RUCAIBox/CRSLab
|
||||
|
||||
"""
|
||||
# dataset & dataloader
|
||||
if isinstance(config['tokenize'], str):
|
||||
CRS_dataset = get_dataset(config, config['tokenize'], restore_data, save_data)
|
||||
side_data = CRS_dataset.side_data
|
||||
vocab = CRS_dataset.vocab
|
||||
|
||||
train_dataloader = get_dataloader(config, CRS_dataset.train_data, vocab)
|
||||
valid_dataloader = get_dataloader(config, CRS_dataset.valid_data, vocab)
|
||||
test_dataloader = get_dataloader(config, CRS_dataset.test_data, vocab)
|
||||
else:
|
||||
tokenized_dataset = {}
|
||||
train_dataloader = {}
|
||||
valid_dataloader = {}
|
||||
test_dataloader = {}
|
||||
vocab = {}
|
||||
side_data = {}
|
||||
|
||||
for task, tokenize in config['tokenize'].items():
|
||||
if tokenize in tokenized_dataset:
|
||||
dataset = tokenized_dataset[tokenize]
|
||||
else:
|
||||
dataset = get_dataset(config, tokenize, restore_data, save_data)
|
||||
tokenized_dataset[tokenize] = dataset
|
||||
train_data = dataset.train_data
|
||||
valid_data = dataset.valid_data
|
||||
test_data = dataset.test_data
|
||||
train_review = dataset.train_review
|
||||
valid_review = dataset.valid_review
|
||||
test_review = dataset.test_review
|
||||
side_data[task] = dataset.side_data
|
||||
vocab[task] = dataset.vocab
|
||||
review_id2entity = dataset.review_id2entity
|
||||
|
||||
train_dataloader[task] = get_dataloader(config, train_data, train_review, vocab[task], review_id2entity)
|
||||
valid_dataloader[task] = get_dataloader(config, valid_data, valid_review, vocab[task], review_id2entity)
|
||||
test_dataloader[task] = get_dataloader(config, test_data, test_review, vocab[task], review_id2entity)
|
||||
|
||||
# system
|
||||
CRS = get_system(config, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system,
|
||||
interact, debug)
|
||||
if interact:
|
||||
CRS.interact()
|
||||
else:
|
||||
CRS.fit()
|
||||
if save_system:
|
||||
CRS.save_model()
|
||||
|
||||
return
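# Hedged usage sketch (the Config constructor arguments below are an assumption
# of this comment, not something defined in this file):
#
#   from crslab.config import Config
#   from crslab.quick_start import run_crslab
#
#   opt = Config('config/crs/hycorec/redial.yaml')
#   run_crslab(opt, save_data=False, save_system=True)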
|
36
HyCoRec/crslab/system/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# @Time : 2020/11/22
|
||||
# @Author : Kun Zhou
|
||||
# @Email : francis_kun_zhou@163.com
|
||||
|
||||
# UPDATE:
|
||||
# @Time : 2020/11/24, 2020/12/29
|
||||
# @Author : Kun Zhou, Xiaolei Wang
|
||||
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
|
||||
|
||||
# @Time : 2021/10/6
|
||||
# @Author : Zhipeng Zhao
|
||||
# @email : oran_official@outlook.com
|
||||
|
||||
|
||||
from loguru import logger
|
||||
from .hycorec import HyCoRecSystem
|
||||
|
||||
system_register_table = {
|
||||
'HyCoRec': HyCoRecSystem,
|
||||
}
|
||||
|
||||
|
||||
def get_system(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,
|
||||
interact=False, debug=False):
|
||||
"""
|
||||
return the system class
|
||||
"""
|
||||
model_name = opt['model_name']
|
||||
if model_name in system_register_table:
|
||||
system = system_register_table[model_name](opt, train_dataloader, valid_dataloader, test_dataloader, vocab,
|
||||
side_data, restore_system, interact, debug)
|
||||
logger.info(f'[Build system {model_name}]')
|
||||
return system
|
||||
else:
|
||||
raise NotImplementedError('The system with model [{}] in dataset [{}] has not been implemented'.
|
||||
format(model_name, opt['dataset']))
|
290
HyCoRec/crslab/system/base.py
Normal file
@@ -0,0 +1,290 @@
# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/24, 2021/1/9
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

# UPDATE:
# @Time   : 2021/11/5
# @Author : Zhipeng Zhao
# @Email  : oran_official@outlook.com

import os
from abc import ABC, abstractmethod

import numpy as np
import random
import torch
from loguru import logger
from torch import optim
from transformers import Adafactor

from crslab.config import SAVE_PATH
from crslab.evaluator import get_evaluator
from crslab.evaluator.metrics.base import AverageMetric
from crslab.model import get_model
from crslab.system.utils import lr_scheduler
from crslab.system.utils.functions import compute_grad_norm

optim_class = {}
optim_class.update({k: v for k, v in optim.__dict__.items() if not k.startswith('__') and k[0].isupper()})
optim_class.update({'AdamW': optim.AdamW, 'Adafactor': Adafactor})
lr_scheduler_class = {k: v for k, v in lr_scheduler.__dict__.items() if not k.startswith('__') and k[0].isupper()}
transformers_tokenizer = ('bert', 'gpt2')


class BaseSystem(ABC):
    """Base class for all systems"""

    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,
                 interact=False, debug=False):
        """

        Args:
            opt (dict): Indicating the hyper parameters.
            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.
            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.
            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.
            vocab (dict): Indicating the vocabulary.
            side_data (dict): Indicating the side data.
            restore_system (bool, optional): Indicating if we restore the system from a saved checkpoint. Defaults to False.
            interact (bool, optional): Indicating if we interact with the system. Defaults to False.
            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.

        """
        self.opt = opt
        if opt["gpu"] == [-1]:
            self.device = torch.device('cpu')
        elif len(opt["gpu"]) == 1:
            self.device = torch.device('cuda')
        else:
            # multiple GPUs are not handled here; fall back to CPU
            self.device = torch.device('cpu')
        # seed
        if 'seed' in opt:
            seed = int(opt['seed'])
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            logger.info(f'[Set seed] {seed}')
        # data
        if debug:
            self.train_dataloader = valid_dataloader
            self.valid_dataloader = valid_dataloader
            self.test_dataloader = test_dataloader
        else:
            self.train_dataloader = train_dataloader
            self.valid_dataloader = valid_dataloader
            self.test_dataloader = test_dataloader
        self.vocab = vocab
        self.side_data = side_data
        # model
        if 'model' in opt:
            self.model = get_model(opt, opt['model'], self.device, vocab, side_data).to(self.device)
        else:
            if 'rec_model' in opt:
                self.rec_model = get_model(opt, opt['rec_model'], self.device, vocab['rec'], side_data['rec']).to(
                    self.device)
            if 'conv_model' in opt:
                self.conv_model = get_model(opt, opt['conv_model'], self.device, vocab['conv'], side_data['conv']).to(
                    self.device)
            if 'policy_model' in opt:
                self.policy_model = get_model(opt, opt['policy_model'], self.device, vocab['policy'],
                                              side_data['policy']).to(self.device)
        model_file_name = opt.get('model_file', f'{opt["model_name"]}.pth')
        self.model_file = os.path.join(SAVE_PATH, model_file_name)
        if restore_system:
            self.restore_model()

        if not interact:
            self.evaluator = get_evaluator('standard', opt['dataset'], opt['rankfile'])

    def init_optim(self, opt, parameters):
        self.optim_opt = opt
        parameters = list(parameters)
        if isinstance(parameters[0], dict):
            for i, d in enumerate(parameters):
                parameters[i]['params'] = list(d['params'])

        # gradient accumulation
        self.update_freq = opt.get('update_freq', 1)
        self._number_grad_accum = 0

        self.gradient_clip = opt.get('gradient_clip', -1)

        self.build_optimizer(parameters)
        self.build_lr_scheduler()

        if isinstance(parameters[0], dict):
            self.parameters = []
            for d in parameters:
                self.parameters.extend(d['params'])
        else:
            self.parameters = parameters

        # early stop
        self.need_early_stop = self.optim_opt.get('early_stop', False)
        if self.need_early_stop:
            logger.debug('[Enable early stop]')
            self.reset_early_stop_state()

    def build_optimizer(self, parameters):
        optimizer_opt = self.optim_opt['optimizer']
        optimizer = optimizer_opt.pop('name')
        self.optimizer = optim_class[optimizer](parameters, **optimizer_opt)
        logger.info(f"[Build optimizer: {optimizer}]")

    def build_lr_scheduler(self):
        """
        Create the learning rate scheduler, and assign it to self.scheduler. This
        scheduler will be updated upon a call to receive_metrics. May also create
        self.warmup_scheduler, if appropriate.
        """
        if self.optim_opt.get('lr_scheduler', None):
            lr_scheduler_opt = self.optim_opt['lr_scheduler']
            lr_scheduler = lr_scheduler_opt.pop('name')
            self.scheduler = lr_scheduler_class[lr_scheduler](self.optimizer, **lr_scheduler_opt)
            logger.info(f"[Build scheduler {lr_scheduler}]")

    def reset_early_stop_state(self):
        self.best_valid = None
        self.drop_cnt = 0
        self.impatience = self.optim_opt.get('impatience', 3)
        if self.optim_opt['stop_mode'] == 'max':
            self.stop_mode = 1
        elif self.optim_opt['stop_mode'] == 'min':
            self.stop_mode = -1
        else:
            raise ValueError(f"Unexpected stop_mode: {self.optim_opt['stop_mode']}")
        logger.debug('[Reset early stop state]')

    @abstractmethod
    def fit(self):
        """fit the whole system"""
        pass

    @abstractmethod
    def step(self, batch, stage, mode):
        """calculate loss and prediction for batch data under certain stage and mode

        Args:
            batch (dict or tuple): batch data
            stage (str): recommendation/policy/conversation etc.
            mode (str): train/valid/test
        """
        pass

    def backward(self, loss):
        """empty grad, backward loss and update params

        Args:
            loss (torch.Tensor):
        """
        self._zero_grad()

        if self.update_freq > 1:
            self._number_grad_accum = (self._number_grad_accum + 1) % self.update_freq
            loss /= self.update_freq
        loss.backward(loss.clone().detach())

        self._update_params()

    def _zero_grad(self):
        if self._number_grad_accum != 0:
            # if we're accumulating gradients, don't actually zero things out yet.
            return
        self.optimizer.zero_grad()

    def _update_params(self):
        if self.update_freq > 1:
            # we're doing gradient accumulation, so we only want to step
            # every N updates instead
            # self._number_grad_accum is updated in the backward function
            if self._number_grad_accum != 0:
                return

        if self.gradient_clip > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.parameters, self.gradient_clip
            )
            self.evaluator.optim_metrics.add('grad norm', AverageMetric(grad_norm))
            self.evaluator.optim_metrics.add(
                'grad clip ratio',
                AverageMetric(float(grad_norm > self.gradient_clip)),
            )
        else:
            grad_norm = compute_grad_norm(self.parameters)
            self.evaluator.optim_metrics.add('grad norm', AverageMetric(grad_norm))

        self.optimizer.step()

        if hasattr(self, 'scheduler'):
            self.scheduler.train_step()

    def adjust_lr(self, metric=None):
        """adjust learning rate w/o metric by scheduler

        Args:
            metric (optional): Defaults to None.
        """
        if not hasattr(self, 'scheduler') or self.scheduler is None:
            return
        self.scheduler.valid_step(metric)
        logger.debug('[Adjust learning rate after valid epoch]')

    def early_stop(self, metric):
        if not self.need_early_stop:
            return False
        if self.best_valid is None or metric * self.stop_mode > self.best_valid * self.stop_mode:
            self.best_valid = metric
            self.drop_cnt = 0
            logger.info('[Get new best model]')
            return False
        else:
            self.drop_cnt += 1
            if self.drop_cnt >= self.impatience:
                logger.info('[Early stop]')
                return True

    def save_model(self):
        r"""Store the model parameters."""
        state = {}
        if hasattr(self, 'model'):
            state['model_state_dict'] = self.model.state_dict()
        if hasattr(self, 'rec_model'):
            state['rec_state_dict'] = self.rec_model.state_dict()
        if hasattr(self, 'conv_model'):
            state['conv_state_dict'] = self.conv_model.state_dict()
        if hasattr(self, 'policy_model'):
            state['policy_state_dict'] = self.policy_model.state_dict()

        os.makedirs(SAVE_PATH, exist_ok=True)
        torch.save(state, self.model_file)
        logger.info(f'[Save model into {self.model_file}]')

    def restore_model(self):
        r"""Restore the model parameters from a saved checkpoint."""
        if not os.path.exists(self.model_file):
            raise ValueError(f'Saved model [{self.model_file}] does not exist')
        checkpoint = torch.load(self.model_file, map_location=self.device)
        if hasattr(self, 'model'):
            self.model.load_state_dict(checkpoint['model_state_dict'])
        if hasattr(self, 'rec_model'):
            self.rec_model.load_state_dict(checkpoint['rec_state_dict'])
        if hasattr(self, 'conv_model'):
            self.conv_model.load_state_dict(checkpoint['conv_state_dict'])
        if hasattr(self, 'policy_model'):
            self.policy_model.load_state_dict(checkpoint['policy_state_dict'])
        logger.info(f'[Restore model from {self.model_file}]')

    @abstractmethod
    def interact(self):
        pass
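A standalone sketch (illustrative, not from the repo) of the accumulation-and-clipping schedule that backward(), _zero_grad() and _update_params() implement together: gradients are accumulated over `update_freq` mini-batches, the loss is scaled accordingly, and one clipped optimizer step is taken per window. The model, data and hyperparameter values below are placeholders.

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
update_freq, gradient_clip = 2, 0.1      # mirrors opt['update_freq'] / opt['gradient_clip']
number_grad_accum = 0

for step in range(4):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    if number_grad_accum == 0:           # _zero_grad(): only reset at the start of a window
        optimizer.zero_grad()
    number_grad_accum = (number_grad_accum + 1) % update_freq
    (loss / update_freq).backward()      # scale so the accumulated gradient matches a full batch
    if number_grad_accum == 0:           # _update_params(): step once per window, with clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()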
176
HyCoRec/crslab/system/hycorec.py
Normal file
176
HyCoRec/crslab/system/hycorec.py
Normal file
@@ -0,0 +1,176 @@
# -*- encoding: utf-8 -*-
# @Time   : 2021/5/26
# @Author : Chenzhan Shang
# @email  : czshang@outlook.com

import os
import json
from time import perf_counter

import torch
import pickle as pkl
from loguru import logger

from crslab.evaluator.metrics.base import AverageMetric
from crslab.evaluator.metrics.gen import PPLMetric
from crslab.system.base import BaseSystem
from crslab.system.utils.functions import ind2txt


class HyCoRecSystem(BaseSystem):
    """This is the system for the HyCoRec model"""

    def __init__(self, opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data, restore_system=False,
                 interact=False, debug=False):
        """

        Args:
            opt (dict): Indicating the hyper parameters.
            train_dataloader (BaseDataLoader): Indicating the train dataloader of corresponding dataset.
            valid_dataloader (BaseDataLoader): Indicating the valid dataloader of corresponding dataset.
            test_dataloader (BaseDataLoader): Indicating the test dataloader of corresponding dataset.
            vocab (dict): Indicating the vocabulary.
            side_data (dict): Indicating the side data.
            restore_system (bool, optional): Indicating if we restore the system from a saved checkpoint. Defaults to False.
            interact (bool, optional): Indicating if we interact with the system. Defaults to False.
            debug (bool, optional): Indicating if we train in debug mode. Defaults to False.

        """
        super(HyCoRecSystem, self).__init__(opt, train_dataloader, valid_dataloader, test_dataloader, vocab, side_data,
                                            restore_system, interact, debug)

        self.ind2tok = vocab['ind2tok']
        self.end_token_idx = vocab['tok2ind']['__end__']
        self.item_ids = side_data['item_entity_ids']

        self.rec_optim_opt = opt['rec']
        self.conv_optim_opt = opt['conv']
        self.rec_epoch = self.rec_optim_opt['epoch']
        self.conv_epoch = self.conv_optim_opt['epoch']
        self.rec_batch_size = self.rec_optim_opt['batch_size']
        self.conv_batch_size = self.conv_optim_opt['batch_size']

    def rec_evaluate(self, rec_predict, item_label):
        rec_predict = rec_predict.cpu()
        rec_predict = rec_predict[:, self.item_ids]
        _, rec_ranks = torch.topk(rec_predict, 50, dim=-1)
        rec_ranks = rec_ranks.tolist()
        item_label = item_label.tolist()
        # start = perf_counter()
        for rec_rank, label in zip(rec_ranks, item_label):
            label = self.item_ids.index(label)
            self.evaluator.rec_evaluate(rec_rank, label)
        # print(f"{perf_counter() - start}")

    def conv_evaluate(self, prediction, response, batch_user_id=None, batch_conv_id=None):
        prediction = prediction.tolist()
        response = response.tolist()
        if batch_user_id is None:
            for p, r in zip(prediction, response):
                p_str = ind2txt(p, self.ind2tok, self.end_token_idx)
                r_str = ind2txt(r, self.ind2tok, self.end_token_idx)
                self.evaluator.gen_evaluate(p_str, [r_str], p)
        else:
            for p, r, uid, cid in zip(prediction, response, batch_user_id, batch_conv_id):
                p_str = ind2txt(p, self.ind2tok, self.end_token_idx)
                r_str = ind2txt(r, self.ind2tok, self.end_token_idx)
                self.evaluator.gen_evaluate(p_str, [r_str], p)

    def step(self, batch, stage, mode):
        assert stage in ('rec', 'conv')
        assert mode in ('train', 'valid', 'test')

        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device)

        if stage == 'rec':
            rec_loss, rec_scores = self.model.forward(batch, mode, stage)
            rec_loss = rec_loss.sum()
            if mode == 'train':
                self.backward(rec_loss)
            else:
                self.rec_evaluate(rec_scores, batch['item'])
            rec_loss = rec_loss.item()
            self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss))
        else:
            if mode != 'test':
                gen_loss, preds = self.model.forward(batch, mode, stage)
                if mode == 'train':
                    self.backward(gen_loss)
                else:
                    self.conv_evaluate(preds, batch['response'])
                gen_loss = gen_loss.item()
                self.evaluator.optim_metrics.add('gen_loss', AverageMetric(gen_loss))
                self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
            else:
                preds = self.model.forward(batch, mode, stage)
                self.conv_evaluate(preds, batch['response'], batch.get('user_id', None), batch['conv_id'])

    def train_recommender(self):
        self.init_optim(self.rec_optim_opt, self.model.parameters())

        for epoch in range(self.rec_epoch):
            self.evaluator.reset_metrics()
            logger.info(f'[Recommendation epoch {str(epoch)}]')
            logger.info('[Train]')
            for batch in self.train_dataloader.get_rec_data(self.rec_batch_size):
                self.step(batch, stage='rec', mode='train')
            self.evaluator.report(epoch=epoch, mode='train')
            # val
            logger.info('[Valid]')
            with torch.no_grad():
                self.evaluator.reset_metrics()
                for batch in self.valid_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):
                    self.step(batch, stage='rec', mode='valid')
                self.evaluator.report(epoch=epoch, mode='valid')
                # early stop
                metric = self.evaluator.optim_metrics['rec_loss']
                if self.early_stop(metric):
                    break
        # test
        logger.info('[Test]')
        with torch.no_grad():
            self.evaluator.reset_metrics()
            for batch in self.test_dataloader.get_rec_data(self.rec_batch_size, shuffle=False):
                self.step(batch, stage='rec', mode='test')
            self.evaluator.report(mode='test')

    def train_conversation(self):
        if os.environ["CUDA_VISIBLE_DEVICES"] == '-1':
            self.model.freeze_parameters()
        else:
            self.model.module.freeze_parameters()
        self.init_optim(self.conv_optim_opt, self.model.parameters())

        for epoch in range(self.conv_epoch):
            self.evaluator.reset_metrics()
            logger.info(f'[Conversation epoch {str(epoch)}]')
            logger.info('[Train]')
            for batch in self.train_dataloader.get_conv_data(batch_size=self.conv_batch_size):
                self.step(batch, stage='conv', mode='train')
            self.evaluator.report(epoch=epoch, mode='train')
            # val
            logger.info('[Valid]')
            with torch.no_grad():
                self.evaluator.reset_metrics()
                for batch in self.valid_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):
                    self.step(batch, stage='conv', mode='valid')
                self.evaluator.report(epoch=epoch, mode='valid')
                # early stop
                metric = self.evaluator.optim_metrics['gen_loss']
                if self.early_stop(metric):
                    break
        # test
        logger.info('[Test]')
        with torch.no_grad():
            self.evaluator.reset_metrics()
            for batch in self.test_dataloader.get_conv_data(batch_size=self.conv_batch_size, shuffle=False):
                self.step(batch, stage='conv', mode='test')
            self.evaluator.report(mode='test')

    def fit(self):
        self.train_recommender()
        self.train_conversation()

    def interact(self):
        pass
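A standalone sketch (illustrative values, not from the repo) of the index mapping rec_evaluate performs: scores over all KG entities are restricted to the item-entity columns, ranked with topk, and the gold item id is re-expressed as its position inside item_ids so that it matches the ranked indices handed to the evaluator.

import torch

item_ids = [7, 12, 30, 45]                 # entity ids that are items (side_data['item_entity_ids'])
scores = torch.randn(2, 50)                # model scores over all 50 entities, batch of 2
item_scores = scores[:, item_ids]          # keep only the item columns
_, ranks = torch.topk(item_scores, k=3, dim=-1)
gold = torch.tensor([30, 7])               # gold item entity ids for the batch
for rank, label in zip(ranks.tolist(), gold.tolist()):
    label_pos = item_ids.index(label)      # position inside item_ids, same index space as `rank`
    print(label_pos in rank)               # True when the gold item is in the top-3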
0
HyCoRec/crslab/system/utils/__init__.py
Normal file
0
HyCoRec/crslab/system/utils/__init__.py
Normal file
66
HyCoRec/crslab/system/utils/functions.py
Normal file
66
HyCoRec/crslab/system/utils/functions.py
Normal file
@@ -0,0 +1,66 @@
# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/24, 2020/12/18
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

# UPDATE:
# @Time   : 2021/10/05
# @Author : Zhipeng Zhao
# @email  : oran_official@outlook.com

import torch


def compute_grad_norm(parameters, norm_type=2.0):
    """
    Compute norm over gradients of model parameters.

    :param parameters:
        the model parameters for gradient norm calculation. Iterable of
        Tensors or single Tensor
    :param norm_type:
        type of p-norm to use

    :returns:
        the computed gradient norm
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p is not None and p.grad is not None]
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    return total_norm ** (1.0 / norm_type)


def ind2txt(inds, ind2tok, end_token_idx=None, unk_token='unk'):
    sentence = []
    for ind in inds:
        if isinstance(ind, torch.Tensor):
            ind = ind.item()
        if end_token_idx and ind == end_token_idx:
            break
        sentence.append(ind2tok.get(ind, unk_token))
    return ' '.join(sentence)


def ind2txt_with_slots(inds, slots, ind2tok, end_token_idx=None, unk_token='unk', slot_token='[ITEM]'):
    sentence = []
    for ind in inds:
        if isinstance(ind, torch.Tensor):
            ind = ind.item()
        if end_token_idx and ind == end_token_idx:
            break
        token = ind2tok.get(ind, unk_token)
        if token == slot_token:
            token = slots[0]
            slots = slots[1:]
        sentence.append(token)
    return ' '.join(sentence)


def ind2slot(inds, ind2slot):
    return [ind2slot[ind] for ind in inds]
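A quick usage sketch of ind2txt (hypothetical vocabulary; assumes the crslab package from this commit is importable): token ids are mapped back to words, decoding stops at the end token, and unknown ids fall back to 'unk'.

from crslab.system.utils.functions import ind2txt

ind2tok = {1: 'i', 2: 'like', 3: 'movies', 4: '__end__'}
print(ind2txt([1, 2, 3, 4, 2], ind2tok, end_token_idx=4))  # -> "i like movies"
print(ind2txt([1, 99], ind2tok, end_token_idx=4))          # -> "i unk"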
316
HyCoRec/crslab/system/utils/lr_scheduler.py
Normal file
316
HyCoRec/crslab/system/utils/lr_scheduler.py
Normal file
@@ -0,0 +1,316 @@
# @Time   : 2020/12/1
# @Author : Xiaolei Wang
# @Email  : wxl1999@foxmail.com

# UPDATE:
# @Time   : 2020/12/14
# @Author : Xiaolei Wang
# @Email  : wxl1999@foxmail.com

import math
from abc import abstractmethod, ABC

import numpy as np
import torch
from loguru import logger
from torch import optim


class LRScheduler(ABC):
    """
    Class for LR Schedulers.

    Includes some basic functionality by default - setting up the warmup
    scheduler, passing the correct number of steps to train_step, loading and
    saving states.
    Subclasses must implement abstract methods train_adjust() and valid_adjust().
    Schedulers are built by BaseSystem.build_lr_scheduler() from the 'lr_scheduler' config.
    __init__() should not be called directly.
    """

    def __init__(self, optimizer, warmup_steps: int = 0):
        """
        Initialize warmup scheduler. Specific main schedulers should be initialized in
        the subclasses. Do not invoke this method directly.

        :param optimizer optimizer:
            Optimizer being used for training. May be wrapped in
            fp16_optimizer_wrapper depending on whether fp16 is used.
        :param int warmup_steps:
            Number of training step updates warmup scheduler should take.
        """
        self._number_training_updates = 0
        self.warmup_steps = warmup_steps
        self._init_warmup_scheduler(optimizer)

    def _warmup_lr(self, step):
        """
        Return lr multiplier (on initial lr) for warmup scheduler.
        """
        return float(step) / float(max(1, self.warmup_steps))

    def _init_warmup_scheduler(self, optimizer):
        if self.warmup_steps > 0:
            self.warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, self._warmup_lr)
        else:
            self.warmup_scheduler = None

    def _is_lr_warming_up(self):
        """
        Check if we're warming up the learning rate.
        """
        return (
            hasattr(self, 'warmup_scheduler')
            and self.warmup_scheduler is not None
            and self._number_training_updates <= self.warmup_steps
        )

    def train_step(self):
        """
        Use the number of train steps to adjust the warmup scheduler or the main
        scheduler, depending on where in training we are.

        Override this method to override the behavior for training schedulers.
        """
        self._number_training_updates += 1
        if self._is_lr_warming_up():
            self.warmup_scheduler.step()
        else:
            self.train_adjust()

    def valid_step(self, metric=None):
        if self._is_lr_warming_up():
            # we're not done warming up, so don't start using validation
            # metrics to adjust schedule
            return
        self.valid_adjust(metric)

    @abstractmethod
    def train_adjust(self):
        """
        Use the number of train steps to decide when to adjust LR schedule.

        Override this method to override the behavior for training schedulers.
        """
        pass

    @abstractmethod
    def valid_adjust(self, metric):
        """
        Use the metrics to decide when to adjust LR schedule.

        This uses the loss as the validation metric if present, if not this
        function does nothing. Note that the model must be reporting loss for
        this to work.

        Override this method to override the behavior for validation schedulers.
        """
        pass


class ReduceLROnPlateau(LRScheduler):
    """
    Scheduler that decays by a multiplicative rate when valid loss plateaus.
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001,
                 threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, warmup_steps=0):
        super(ReduceLROnPlateau, self).__init__(optimizer, warmup_steps)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode=mode, factor=factor,
                                                              patience=patience, threshold=threshold,
                                                              threshold_mode=threshold_mode, cooldown=cooldown,
                                                              min_lr=min_lr, eps=eps)

    def train_adjust(self):
        pass

    def valid_adjust(self, metric):
        self.scheduler.step(metric)


class StepLR(LRScheduler):
    """
    Scheduler that decays by a fixed multiplicative rate at each valid step.
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, warmup_steps=0):
        super(StepLR, self).__init__(optimizer, warmup_steps)
        self.scheduler = optim.lr_scheduler.StepLR(optimizer, step_size, gamma, last_epoch)

    def train_adjust(self):
        pass

    def valid_adjust(self, metric=None):
        self.scheduler.step()


class ConstantLR(LRScheduler):
    def __init__(self, optimizer, warmup_steps=0):
        super(ConstantLR, self).__init__(optimizer, warmup_steps)

    def train_adjust(self):
        pass

    def valid_adjust(self, metric):
        pass


class InvSqrtLR(LRScheduler):
    """
    Scheduler that decays at an inverse square root rate.
    """

    def __init__(self, optimizer, invsqrt_lr_decay_gamma=-1, last_epoch=-1, warmup_steps=0):
        """
        invsqrt_lr_decay_gamma determines the cycle length of the inverse square root
        scheduler.

        When steps taken == invsqrt_lr_decay_gamma, the lr multiplier is 1
        """
        super(InvSqrtLR, self).__init__(optimizer, warmup_steps)
        self.invsqrt_lr_decay_gamma = invsqrt_lr_decay_gamma
        if invsqrt_lr_decay_gamma <= 0:
            logger.warning(
                '--lr-scheduler invsqrt requires a value for '
                '--invsqrt-lr-decay-gamma. Defaulting to set gamma to '
                '--warmup-updates value for backwards compatibility.'
            )
            self.invsqrt_lr_decay_gamma = self.warmup_steps

        self.decay_factor = np.sqrt(max(1, self.invsqrt_lr_decay_gamma))
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._invsqrt_lr, last_epoch)

    def _invsqrt_lr(self, step):
        return self.decay_factor / np.sqrt(max(1, self.invsqrt_lr_decay_gamma + step))

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        # this is a training step lr scheduler, nothing to adjust in validation
        pass


class CosineAnnealingLR(LRScheduler):
    """
    Scheduler that decays by a cosine function.
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, warmup_steps=0):
        """
        T_max determines the cycle length of the cosine annealing.

        It indicates the number of steps from 1.0 multiplier to 0.0, which corresponds
        to going from cos(0) to cos(pi)
        """
        super(CosineAnnealingLR, self).__init__(optimizer, warmup_steps)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min, last_epoch)

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        pass


class CosineAnnealingWarmRestartsLR(LRScheduler):
    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, warmup_steps=0):
        super(CosineAnnealingWarmRestartsLR, self).__init__(optimizer, warmup_steps)
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult, eta_min, last_epoch)

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        pass


class TransformersLinearLR(LRScheduler):
    """
    Scheduler that decays linearly.
    """

    def __init__(self, optimizer, training_steps, warmup_steps=0):
        """
        training_steps determines the cycle length of the linear annealing.

        It indicates the number of steps from 1.0 multiplier to 0.0
        """
        super(TransformersLinearLR, self).__init__(optimizer, warmup_steps)
        self.training_steps = training_steps - warmup_steps
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._linear_lr)

    def _linear_lr(self, step):
        return max(0.0, float(self.training_steps - step) / float(max(1, self.training_steps)))

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        pass


class TransformersCosineLR(LRScheduler):
    def __init__(self, optimizer, training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1,
                 warmup_steps: int = 0):
        super(TransformersCosineLR, self).__init__(optimizer, warmup_steps)
        self.training_steps = training_steps - warmup_steps
        self.num_cycles = num_cycles
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._cosine_lr, last_epoch)

    def _cosine_lr(self, step):
        progress = float(step) / float(max(1, self.training_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        pass


class TransformersCosineWithHardRestartsLR(LRScheduler):
    def __init__(self, optimizer, training_steps: int, num_cycles: int = 1, last_epoch: int = -1,
                 warmup_steps: int = 0):
        super(TransformersCosineWithHardRestartsLR, self).__init__(optimizer, warmup_steps)
        self.training_steps = training_steps - warmup_steps
        self.num_cycles = num_cycles
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._cosine_with_hard_restarts_lr, last_epoch)

    def _cosine_with_hard_restarts_lr(self, step):
        progress = float(step) / float(max(1, self.training_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(self.num_cycles) * progress) % 1.0))))

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        pass


class TransformersPolynomialDecayLR(LRScheduler):
    def __init__(self, optimizer, training_steps, lr_end=1e-7, power=1.0, last_epoch=-1, warmup_steps=0):
        super(TransformersPolynomialDecayLR, self).__init__(optimizer, warmup_steps)
        self.training_steps = training_steps - warmup_steps
        self.lr_init = optimizer.defaults["lr"]
        self.lr_end = lr_end
        assert self.lr_init > lr_end, f"lr_end ({lr_end}) must be smaller than initial lr ({self.lr_init})"
        self.power = power
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, self._polynomial_decay_lr, last_epoch)

    def _polynomial_decay_lr(self, step):
        if step > self.training_steps:
            return self.lr_end / self.lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = self.lr_init - self.lr_end
            decay_steps = self.training_steps
            pct_remaining = 1 - step / decay_steps
            decay = lr_range * pct_remaining ** self.power + self.lr_end
            return decay / self.lr_init  # as LambdaLR multiplies by lr_init

    def train_adjust(self):
        self.scheduler.step()

    def valid_adjust(self, metric):
        pass
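A usage sketch of the wrappers above (assumes the crslab package from this commit is importable; model, loop length and the validation value are placeholders). BaseSystem drives them by calling train_step() after every optimizer step and valid_step(metric) after each validation pass, which is what the loop below imitates.

import torch
from crslab.system.utils.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, warmup_steps=100)

for step in range(200):
    # ... forward / backward / optimizer.step() would go here ...
    scheduler.train_step()        # first 100 calls linearly warm up the lr, then train_adjust() (a no-op here)
valid_loss = 0.42                 # placeholder validation metric
scheduler.valid_step(valid_loss)  # after warmup, plateaus in the valid metric trigger lr decay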
45
HyCoRec/run_crslab.py
Normal file
45
HyCoRec/run_crslab.py
Normal file
@@ -0,0 +1,45 @@
# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/24, 2021/1/9
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

import argparse
import warnings

from crslab.config import Config

warnings.filterwarnings('ignore')

if __name__ == '__main__':
    # parse args
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str,
                        default='config/crs/hycorec/hredial.yaml', help='config file(yaml) path')
    parser.add_argument('-g', '--gpu', type=str, default='-1',
                        help='specify GPU id(s) to use, we now support multiple GPUs. Defaults to CPU(-1).')
    parser.add_argument('-sd', '--save_data', action='store_true',
                        help='save processed dataset')
    parser.add_argument('-rd', '--restore_data', action='store_true',
                        help='restore processed dataset')
    parser.add_argument('-ss', '--save_system', action='store_true',
                        help='save trained system')
    parser.add_argument('-rs', '--restore_system', action='store_true',
                        help='restore trained system')
    parser.add_argument('-d', '--debug', action='store_true',
                        help='use valid dataset to debug your system')
    parser.add_argument('-i', '--interact', action='store_true',
                        help='interact with your system instead of training')
    parser.add_argument('-s', '--seed', type=int, default=2020)
    parser.add_argument('-p', '--pretrain', action='store_true', help='use pretrain weights')
    parser.add_argument('-e', '--pretrain_epoch', type=int, default=9999, help='pretrain epoch')
    args, _ = parser.parse_known_args()
    config = Config(args.config, args.gpu, args.debug, args.seed, args.pretrain, args.pretrain_epoch)

    from crslab.quick_start import run_crslab

    run_crslab(config, args.save_data, args.restore_data, args.save_system, args.restore_system, args.interact,
               args.debug)
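A programmatic equivalent of the entry script (a sketch; assumes the crslab package from this commit is importable and points at the ReDial config shipped with it). The positional arguments mirror the call sites in run_crslab.py: Config takes (config_path, gpu, debug, seed, pretrain, pretrain_epoch) and run_crslab takes the config followed by the six boolean flags; the flag meanings in the comment are inferred from the argparse names above.

from crslab.config import Config
from crslab.quick_start import run_crslab

config = Config('config/crs/hycorec/redial.yaml', '-1', False, 2020, False, 9999)
# positional flags: save_data, restore_data, save_system, restore_system, interact, debug
run_crslab(config, False, False, True, False, False, False)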