31 #ifndef SRC_TESTS_PR2018_UTIL_HPP_ 32 #define SRC_TESTS_PR2018_UTIL_HPP_ 34 #define GXL_GEDLIB_SHARED 35 #include "../../../src/env/ged_env.hpp" 39 bool is_synth_mol_dataset(
const std::string & dataset) {
40 return ((dataset ==
"S-MOL_NL01") or (dataset ==
"S-MOL_NL04") or (dataset ==
"S-MOL_NL07") or (dataset ==
"S-MOL_NL10"));
43 bool is_chemical_dataset(
const std::string & dataset) {
44 return (is_synth_mol_dataset(dataset) or (dataset ==
"AIDS") or (dataset ==
"Mutagenicity") or (dataset ==
"acyclic") or (dataset ==
"alkane") or (dataset ==
"mao") or (dataset ==
"pah") );
/*!
 * @brief Checks whether a dataset name denotes one of the Letter datasets.
 * @param[in] dataset Name of the dataset.
 * @return True if @p dataset is "Letter_HIGH", "Letter_LOW" or "Letter_MED".
 * @note Declared `inline` because it is defined in a header (ODR).
 */
inline bool is_letter_dataset(const std::string & dataset) {
	return ((dataset == "Letter_HIGH") or (dataset == "Letter_LOW") or (dataset == "Letter_MED"));
}
51 void check_dataset(
const std::string & dataset) {
52 if (not (is_chemical_dataset(dataset) or is_letter_dataset(dataset) or (dataset ==
"CMU-GED") or (dataset ==
"Fingerprint") or (dataset ==
"GREC") or (dataset ==
"Protein"))) {
53 throw ged::Error(std::string(
"Dataset \"") + dataset +
"\" does not exists.");
57 std::string graph_dir(
const std::string & dataset) {
58 std::string root_dir(
"../../../data/datasets/");
59 if ((dataset ==
"AIDS") or (dataset ==
"Fingerprint") or (dataset ==
"GREC") or (dataset ==
"Protein") or (dataset ==
"Mutagenicity")) {
60 return (root_dir + dataset +
"/data/");
62 else if ((dataset ==
"Letter_HIGH")) {
63 return (root_dir +
"Letter/HIGH/");
65 else if ((dataset ==
"Letter_LOW")) {
66 return (root_dir +
"Letter/LOW/");
68 else if ((dataset ==
"Letter_MED")) {
69 return (root_dir +
"Letter/MED/");
71 else if (dataset ==
"S-MOL_NL01") {
72 return (root_dir +
"S-MOL/NL01/");
74 else if (dataset ==
"S-MOL_NL04") {
75 return (root_dir +
"S-MOL/NL04/");
77 else if (dataset ==
"S-MOL_NL07") {
78 return (root_dir +
"S-MOL/NL07/");
80 else if (dataset ==
"S-MOL_NL10") {
81 return (root_dir +
"S-MOL/NL10/");
83 else if (dataset ==
"CMU-GED") {
84 return (root_dir + dataset +
"/CMU/");
86 else if ((dataset ==
"acyclic") or (dataset ==
"alkane") or (dataset ==
"mao") or (dataset ==
"pah")) {
87 return (root_dir + dataset +
"/");
90 throw ged::Error(std::string(
"Dataset \"") + dataset +
"\" does not exists.");
95 std::string train_collection(
const std::string & dataset) {
96 std::string root_dir(
"../collections/");
97 check_dataset(dataset);
98 if (is_letter_dataset(dataset)) {
99 return (root_dir +
"Letter_50.xml");
101 if (is_synth_mol_dataset(dataset)) {
102 return (root_dir +
"S-MOL_50.xml");
104 return root_dir + dataset +
"_50.xml";
107 std::string test_collection(
const std::string & dataset) {
108 std::string root_dir(
"../collections/");
109 check_dataset(dataset);
110 if (is_letter_dataset(dataset)) {
111 return (root_dir +
"Letter_100.xml");
113 if (is_synth_mol_dataset(dataset)) {
114 return (root_dir +
"S-MOL_100.xml");
116 return root_dir + dataset +
"_100.xml";
119 std::string config_prefix(
const std::string & dataset) {
120 check_dataset(dataset);
121 return std::string(
"../output/" + dataset +
"_");
124 std::string init_options(
const std::string & dataset,
const std::string & config_suffix,
const std::string & data_suffix =
"",
bool save_train =
false,
bool load_train =
false, std::size_t threads = 8) {
125 check_dataset(dataset);
126 std::string options(
"--threads ");
127 options += std::to_string(threads) +
" --save ../output/";
128 options += dataset +
"_" + config_suffix +
".ini";
131 throw ged::Error(
"Training data cannot be both saved and loaded.");
133 options +=
" --save-train ../output/" + dataset +
"_" + data_suffix +
".data";
136 options +=
" --load-train ../output/" + dataset +
"_" + data_suffix +
".data";
// NOTE(review): ground_truth_option(dataset) validates `dataset` with
// check_dataset(); its only visible return yields " --ground-truth-method IPFP".
// The embedded source line numbers jump (142 -> 146, 150 -> 153, ...), so lines
// are missing from this copy: the per-dataset branch bodies are absent, and
// whatever guards the early `return` (e.g. a preprocessor conditional) is not
// visible. TODO: restore the missing lines from the original source.
141 std::string ground_truth_option(
const std::string & dataset) {
142 check_dataset(dataset);
// NOTE(review): as visible here this return is unconditional, which would make
// every branch below unreachable; the lines that presumably guard it are missing.
146 return std::string(
" --ground-truth-method IPFP");
150 if (is_chemical_dataset(dataset)) {
153 else if (is_letter_dataset(dataset)) {
156 else if (dataset ==
"CMU-GED") {
159 else if (dataset ==
"Fingerprint") {
162 else if (dataset ==
"GREC") {
165 else if (dataset ==
"Protein") {
169 throw ged::Error(std::string(
"Dataset \"") + dataset +
"\" does not exists.");
// NOTE(review): fragment of a function whose signature (embedded lines ~171-174)
// and remaining body (177+) are missing from this copy. It validates `dataset`
// and special-cases "Fingerprint" and "CMU-GED". Presumably this is
// node_type(dataset), which is passed to load_gxl_graphs() further below and
// would return a GXLNodeEdgeType — confirm against the original source.
175 check_dataset(dataset);
176 if ((dataset ==
"Fingerprint") or (dataset ==
"CMU-GED")) {
// NOTE(review): fragment of another function whose signature and body are
// partly missing (embedded line numbers jump to 183). It validates `dataset`
// and special-cases the Letter datasets. Presumably this is edge_type(dataset),
// which is passed to load_gxl_graphs() further below — confirm.
183 check_dataset(dataset);
184 if (is_letter_dataset(dataset)) {
190 std::unordered_set<std::string> irrelevant_node_attributes(
const std::string & dataset) {
191 check_dataset(dataset);
192 std::unordered_set<std::string> irrelevant_attributes;
193 if ((dataset ==
"AIDS")) {
194 irrelevant_attributes.insert({
"x",
"y",
"symbol"});
196 else if (dataset ==
"Protein") {
197 irrelevant_attributes.insert(
"aaLength");
199 return irrelevant_attributes;
202 std::unordered_set<std::string> irrelevant_edge_attributes(
const std::string & dataset) {
203 check_dataset(dataset);
204 std::unordered_set<std::string> irrelevant_attributes;
205 if ((dataset ==
"GREC")) {
206 irrelevant_attributes.insert({
"angle0",
"angle1"});
208 else if (dataset ==
"Protein") {
209 irrelevant_attributes.insert({
"distance0",
"distance1"});
211 else if (dataset ==
"Fingerprint") {
212 irrelevant_attributes.insert(
"angle");
214 return irrelevant_attributes;
// NOTE(review): fragment of a function whose signature and body are missing
// from this copy (embedded line numbers jump 214 -> 218). It special-cases the
// chemical datasets and "Protein". Presumably this is init_type(dataset), whose
// result is passed to env.init() below — confirm which ged::Options::InitType
// each case selects.
218 if (is_chemical_dataset(dataset) or (dataset ==
"Protein")) {
// NOTE(review): fragment of a function whose signature (embedded lines ~219-224)
// and later lines are missing from this copy. It loads the graphs of `dataset`
// into `env` via load_gxl_graphs() and keeps the returned IDs in `graph_ids`.
// The collection is train_collection(dataset) when `train` is true and
// test_collection(dataset) otherwise. It then calls env.init(init_type(dataset)).
// The types of `env` and `train` are not visible here; presumably ged::GEDEnv
// and bool — confirm.
225 std::vector<ged::GEDGraph::GraphID> graph_ids(env.
load_gxl_graphs(graph_dir(dataset), (train ? train_collection(dataset) : test_collection(dataset)), node_type(dataset), edge_type(dataset), irrelevant_node_attributes(dataset), irrelevant_edge_attributes(dataset)));
227 env.
init(init_type(dataset));
/*!
 * @brief Replaces the contents of @p datasets with a fixed list of dataset names.
 * @param[out] datasets Overwritten with {"Letter_HIGH", "pah", "AIDS", "Protein", "GREC", "Fingerprint"}.
 * @note Declared `inline` because it is defined in a header (ODR).
 */
inline void setup_datasets(std::vector<std::string> & datasets) {
	datasets = {"Letter_HIGH", "pah", "AIDS", "Protein", "GREC", "Fingerprint"};
}
Selects ged::Fingerprint.
void init(Options::InitType init_type=Options::InitType::EAGER_WITHOUT_SHUFFLED_COPIES)
Initializes the environment.
std::vector< GEDGraph::GraphID > load_gxl_graphs(const std::string &graph_dir, const std::string &collection_file, Options::GXLNodeEdgeType node_type=Options::GXLNodeEdgeType::LABELED, Options::GXLNodeEdgeType edge_type=Options::GXLNodeEdgeType::LABELED, const std::unordered_set< std::string > &irrelevant_node_attributes=std::unordered_set< std::string >(), const std::unordered_set< std::string > &irrelevant_edge_attributes=std::unordered_set< std::string >())
Loads graphs given in the GXL file format.
void set_edit_costs(Options::EditCosts edit_costs, std::initializer_list< double > edit_cost_constants={})
Sets the edit costs to one of the predefined edit costs.
GXLNodeEdgeType
Selects whether nodes or edges of graphs given in GXL file format are labeled or unlabeled.
Unlabeled nodes or edges.
InitType
Selects the initialization type of the environment.
Eager initialization, no shuffled graph copies are constructed.
EditCosts
Selects the edit costs.
Lazy initialization, no shuffled graph copies are constructed.
Provides the API of GEDLIB.