GEDLIB  1.0
sample.py
Go to the documentation of this file.
1 #//////////////////////////////////////////////////////////////////////////#
2 # #
3 # Copyright (C) 2018 by David B. Blumenthal #
4 # #
5 # This file is part of GEDLIB. #
6 # #
7 # GEDLIB is free software: you can redistribute it and/or modify it #
8 # under the terms of the GNU Lesser General Public License as published #
9 # by the Free Software Foundation, either version 3 of the License, or #
10 # (at your option) any later version. #
11 # #
12 # GEDLIB is distributed in the hope that it will be useful, #
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of #
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
15 # GNU Lesser General Public License for more details. #
16 # #
17 # You should have received a copy of the GNU Lesser General Public #
18 # License along with GEDLIB. If not, see <http://www.gnu.org/licenses/>. #
19 # #
20 #//////////////////////////////////////////////////////////////////////////#
21 
22 
45 '''
46 Python script that generates a random sample of given size from a given dataset.
47 '''
48 
import argparse
import os
import random
import xml.etree.ElementTree as ET
from xml.sax.saxutils import escape
52 
# Parse the input arguments.
# Exactly one of --size / --size_ratio must be given (mutually exclusive, required group);
# the other one is left as None by argparse.
parser = argparse.ArgumentParser(description="Generates a random sample of given size from a given dataset.")
parser.add_argument("dataset", help="path to existing dataset file")
parser.add_argument("sample", help="path to sample file to be generated by the script")
parser.add_argument("--exclude", help="path to existing file that lists the graphs contained in the dataset which should not appear in the sample")
parser.add_argument("--balanced", help="generate sample with equal number of graphs per class", action="store_true")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--size", help="size of sample; must be greater than 0; if larger than the number of non-excluded graphs in the dataset, it is reduced to that number", type=int)
group.add_argument("--size_ratio", help="size of sample divided by the number of non-excluded graphs in the dataset; must be between 0 and 1", type=float)
args = parser.parse_args()
# Compare resolved paths so that e.g. "./data.xml" and "data.xml" (or a symlink)
# are recognized as the same file; writing the sample would otherwise clobber the dataset.
if os.path.realpath(args.dataset) == os.path.realpath(args.sample):
    raise Exception("dataset file equals sample file")

# Collect excluded graphs: the "file" attribute of every top-level entry of the
# collection passed via --exclude (empty set if the option was not given).
excluded_graphs = set()
if args.exclude:
    excluded_root = ET.parse(args.exclude).getroot()
    excluded_graphs.update(entry.attrib["file"] for entry in excluded_root)

# Collect the classes.
dataset = ET.parse(args.dataset).getroot()
# Maps each graph file to its class label; used when writing the sample.
graph_classes = {graph.attrib["file"]: graph.attrib["class"] for graph in dataset}
# Built from every entry (not from graph_classes) so that no class is lost
# if the same file happens to be listed more than once.
classes = {graph.attrib["class"] for graph in dataset}
num_classes = len(classes)
if num_classes == 0:
    # Fail early with a clear message instead of a less informative error later on.
    raise Exception("dataset contains no graphs")

# Group the graphs that are not excluded by their class, and record how many
# candidates each class has, how many there are in total, and the smallest class size.
candidate_graphs = {cl: [] for cl in classes}
for graph in dataset:
    file_name = graph.attrib["file"]
    if file_name not in excluded_graphs:
        candidate_graphs[graph.attrib["class"]].append(file_name)
candidate_sizes = {cl: len(members) for cl, members in candidate_graphs.items()}
total_candidate_size = sum(candidate_sizes.values())
min_candidate_sizes = min(candidate_sizes.values())

# Determine the number of sampled graphs per class, then sample the graphs.
# Both steps are kept together because they must agree on which of --size /
# --size_ratio was given (argparse guarantees exactly one of them is not None).
if args.size_ratio is not None:
    # Compare against None: a ratio of 0.0 is falsy but still a valid argument.
    if args.size_ratio < 0 or args.size_ratio > 1:
        raise Exception("SIZE_RATIO must be between 0 and 1")
    if args.balanced:
        # Same number of graphs for every class, capped by the smallest class.
        per_class_size = min(min_candidate_sizes, int((total_candidate_size * args.size_ratio) / num_classes))
        sample_sizes = {cl: per_class_size for cl in classes}
    else:
        # Keep the class proportions of the candidate graphs.
        sample_sizes = {cl: int(candidate_sizes[cl] * args.size_ratio) for cl in classes}
else:
    if args.size <= 0:
        raise Exception("SIZE must be greater than 0")
    # Never request more graphs than remain after exclusion.
    args.size = min(args.size, total_candidate_size)
    if args.balanced:
        # Same number of graphs for every class, capped by the smallest class.
        per_class_size = min(min_candidate_sizes, args.size // num_classes)
        sample_sizes = {cl: per_class_size for cl in classes}
    else:
        # None means: draw args.size graphs uniformly from all candidates, ignoring classes.
        sample_sizes = None

# Sample the graphs.
if sample_sizes is None:
    all_candidates = [graph for cl in classes for graph in candidate_graphs[cl]]
    sampled_graphs = random.sample(all_candidates, args.size)
else:
    sampled_graphs = [graph for cl in classes for graph in random.sample(candidate_graphs[cl], sample_sizes[cl])]

# Write sampled graphs to XML file.
# Attribute values are XML-escaped (including double quotes, since they are
# written inside "...") so that file names or class labels containing &, < or "
# still yield well-formed XML. The file is written as UTF-8, which is the
# encoding an XML declaration without an encoding attribute implies.
attribute_entities = {"\"": "&quot;"}
with open(args.sample, "w", encoding="utf-8") as sample_file:
    sample_file.write("<?xml version=\"1.0\"?>")
    sample_file.write("\n<!DOCTYPE GraphCollection SYSTEM \"http://www.inf.unibz.it/~blumenthal/dtd/GraphCollection.dtd\">")
    sample_file.write("\n<GraphCollection>")
    for graph in sampled_graphs:
        file_attr = escape(graph, attribute_entities)
        class_attr = escape(graph_classes[graph], attribute_entities)
        sample_file.write("\n\t<graph file=\"" + file_attr + "\" class=\"" + class_attr + "\"/>")
    sample_file.write("\n</GraphCollection>\n")