김성주

Code and final report

The listings below are the project source files, in order: the entry script, the common utilities (common.py), the configuration (config.py), the interactive predictor (interactive_predict.py), the model base class (model_base.py), the TensorFlow input reader (path_context_reader.py), the dataset preprocessing script (preprocess.py) and its shell pipeline, the Python path extractor (py_extractor.py), and the TensorFlow model (tensorflow_model.py). Each file's internal line numbering restarts at 1.

1 +from vocabularies import VocabType
2 +from config import Config
3 +from interactive_predict import InteractivePredictor
4 +from model_base import Code2VecModelBase
5 +
6 +
7 +def load_model_dynamically(config: Config) -> Code2VecModelBase:
8 + assert config.DL_FRAMEWORK in {'tensorflow', 'keras'}
9 + if config.DL_FRAMEWORK == 'tensorflow':
10 + from tensorflow_model import Code2VecModel
11 + elif config.DL_FRAMEWORK == 'keras':
12 + from keras_model import Code2VecModel
13 + return Code2VecModel(config)
14 +
15 +
16 +if __name__ == '__main__':
17 + config = Config(set_defaults=True, load_from_args=True, verify=True)
18 +
19 + model = load_model_dynamically(config)
20 +
21 + if config.is_training:
22 + model.train()
23 + if config.SAVE_W2V is not None:
24 + model.save_word2vec_format(config.SAVE_W2V, VocabType.Token)
25 + config.log('Origin word vectors saved in word2vec text format in: %s' % config.SAVE_W2V)
26 + if config.SAVE_T2V is not None:
27 + model.save_word2vec_format(config.SAVE_T2V, VocabType.Target)
28 + config.log('Target word vectors saved in word2vec text format in: %s' % config.SAVE_T2V)
29 + if (config.is_testing and not config.is_training) or config.RELEASE:
30 + eval_results = model.evaluate()
31 + if eval_results is not None:
32 + config.log(
33 + str(eval_results).replace('topk', 'top{}'.format(config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))
34 + if config.PREDICT:
35 + predictor = InteractivePredictor(config, model)
36 + predictor.predict()
37 + model.close_session()
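
For reference, a minimal sketch of driving this entry point programmatically instead of via command-line flags. The paths are hypothetical, and it assumes load_model_dynamically from the script above is in scope together with config.py:

from config import Config

# Hypothetical paths; build a Config without argparse and verify it manually.
config = Config(set_defaults=True, load_from_args=False, verify=False)
config.TRAIN_DATA_PATH_PREFIX = 'data/dataset/dataset'   # expects dataset.train.c2v and dataset.dict.c2v
config.MODEL_SAVE_PATH = 'models/dataset/saved_model'
config.verify()

model = load_model_dynamically(config)   # picks the TensorFlow model, since DL_FRAMEWORK defaults to 'tensorflow'
model.train()
model.close_session()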
1 +import re
2 +import numpy as np
3 +import tensorflow as tf
4 +from itertools import takewhile, repeat
5 +from typing import List, Optional, Tuple, Iterable
6 +from datetime import datetime
7 +from collections import OrderedDict
8 +
9 +
10 +class common:
11 +
12 + @staticmethod
13 + def normalize_word(word):
14 + stripped = re.sub(r'[^a-zA-Z]', '', word)
15 + if len(stripped) == 0:
16 + return word.lower()
17 + else:
18 + return stripped.lower()
19 +
20 + @staticmethod
21 + def _load_vocab_from_histogram(path, min_count=0, start_from=0, return_counts=False):
22 + with open(path, 'r') as file:
23 + word_to_index = {}
24 + index_to_word = {}
25 + word_to_count = {}
26 + next_index = start_from
27 + for line in file:
28 + line_values = line.rstrip().split(' ')
29 + if len(line_values) != 2:
30 + continue
31 + word = line_values[0]
32 + count = int(line_values[1])
33 + if count < min_count:
34 + continue
35 + if word in word_to_index:
36 + continue
37 + word_to_index[word] = next_index
38 + index_to_word[next_index] = word
39 + word_to_count[word] = count
40 + next_index += 1
41 + result = word_to_index, index_to_word, next_index - start_from
42 + if return_counts:
43 + result = (*result, word_to_count)
44 + return result
45 +
46 + @staticmethod
47 + def load_vocab_from_histogram(path, min_count=0, start_from=0, max_size=None, return_counts=False):
48 + if max_size is not None:
49 + word_to_index, index_to_word, next_index, word_to_count = \
50 + common._load_vocab_from_histogram(path, min_count, start_from, return_counts=True)
51 + if next_index <= max_size:
52 + results = (word_to_index, index_to_word, next_index)
53 + if return_counts:
54 + results = (*results, word_to_count)
55 + return results
56 + # Take min_count to be one plus the count of the max_size'th word
57 + min_count = sorted(word_to_count.values(), reverse=True)[max_size] + 1
58 + return common._load_vocab_from_histogram(path, min_count, start_from, return_counts)
59 +
60 + @staticmethod
61 + def load_json(json_file):
62 + data = []
63 + with open(json_file, 'r') as file:
64 + for line in file:
65 + current_program = common.process_single_json_line(line)
66 + if current_program is None:
67 + continue
68 + for element, scope in current_program.items():
69 + data.append((element, scope))
70 + return data
71 +
72 + @staticmethod
73 + def load_json_streaming(json_file):
74 + with open(json_file, 'r') as file:
75 + for line in file:
76 + current_program = common.process_single_json_line(line)
77 + if current_program is None:
78 + continue
79 + for element, scope in current_program.items():
80 + yield (element, scope)
81 +
82 + @staticmethod
83 + def save_word2vec_file(output_file, index_to_word, vocab_embedding_matrix: np.ndarray):
84 + assert len(vocab_embedding_matrix.shape) == 2
85 + vocab_size, embedding_dimension = vocab_embedding_matrix.shape
86 + output_file.write('%d %d\n' % (vocab_size, embedding_dimension))
87 + for word_idx in range(0, vocab_size):
88 + assert word_idx in index_to_word
89 + word_str = index_to_word[word_idx]
90 + output_file.write(word_str + ' ')
91 + output_file.write(' '.join(map(str, vocab_embedding_matrix[word_idx])) + '\n')
92 +
93 + @staticmethod
94 + def calculate_max_contexts(file):  # NOTE: relies on process_test_input, which is not defined in this listing
95 + contexts_per_word = common.process_test_input(file)
96 + return max(
97 + [max(l, default=0) for l in [[len(contexts) for contexts in prog.values()] for prog in contexts_per_word]],
98 + default=0)
99 +
100 + @staticmethod
101 + def binary_to_string(binary_string):
102 + return binary_string.decode("utf-8")
103 +
104 + @staticmethod
105 + def binary_to_string_list(binary_string_list):
106 + return [common.binary_to_string(w) for w in binary_string_list]
107 +
108 + @staticmethod
109 + def binary_to_string_matrix(binary_string_matrix):
110 + return [common.binary_to_string_list(l) for l in binary_string_matrix]
111 +
112 + @staticmethod
113 + def load_file_lines(path):
114 + with open(path, 'r') as f:
115 + return f.read().splitlines()
116 +
117 + @staticmethod
118 + def split_to_batches(data_lines, batch_size):
119 + for x in range(0, len(data_lines), batch_size):
120 + yield data_lines[x:x + batch_size]
121 +
122 + @staticmethod
123 + def legal_method_names_checker(special_words, name):
124 + return name != special_words.OOV and re.match(r'^[a-zA-Z_|]+[a-zA-Z_]+[a-zA-Z0-9_]+$', name)
125 +
126 + @staticmethod
127 + def filter_impossible_names(special_words, top_words):
128 + result = list(filter(lambda word: common.legal_method_names_checker(special_words, word), top_words))
129 + return result
130 +
131 + @staticmethod
132 + def get_subtokens(word):
133 + return word.split('|')
134 +
135 + @staticmethod
136 + def parse_prediction_results(raw_prediction_results, unhash_dict, special_words, topk: int = 5) -> List['MethodPredictionResults']:
137 + prediction_results = []
138 + for single_method_prediction in raw_prediction_results:
139 + current_method_prediction_results = MethodPredictionResults(single_method_prediction.original_name)
140 + for i, predicted in enumerate(single_method_prediction.topk_predicted_words):
141 + if predicted == special_words.OOV:
142 + continue
143 + suggestion_subtokens = common.get_subtokens(predicted)
144 + current_method_prediction_results.append_prediction(
145 + suggestion_subtokens, single_method_prediction.topk_predicted_words_scores[i].item())
146 + topk_attention_per_context = [
147 + (key, single_method_prediction.attention_per_context[key])
148 + for key in sorted(single_method_prediction.attention_per_context,
149 + key=single_method_prediction.attention_per_context.get, reverse=True)
150 + ][:topk]
151 + for context, attention in topk_attention_per_context:
152 + token1, hashed_path, token2 = context
153 + if hashed_path in unhash_dict:
154 + unhashed_path = unhash_dict[hashed_path]
155 + current_method_prediction_results.append_attention_path(attention.item(), token1=token1,
156 + path=unhashed_path, token2=token2)
157 + prediction_results.append(current_method_prediction_results)
158 + return prediction_results
159 +
160 + @staticmethod
161 + def tf_get_first_true(bool_tensor: tf.Tensor) -> tf.Tensor:
162 + bool_tensor_as_int32 = tf.cast(bool_tensor, dtype=tf.int32)
163 + cumsum = tf.cumsum(bool_tensor_as_int32, axis=-1, exclusive=False)
164 + return tf.logical_and(tf.equal(cumsum, 1), bool_tensor)
165 +
166 + @staticmethod
167 + def count_lines_in_file(file_path: str):
168 + with open(file_path, 'rb') as f:
169 + bufgen = takewhile(lambda x: x, (f.raw.read(1024 * 1024) for _ in repeat(None)))
170 + return sum(buf.count(b'\n') for buf in bufgen)
171 +
172 + @staticmethod
173 + def squeeze_single_batch_dimension_for_np_arrays(arrays):
174 + assert all(array is None or isinstance(array, np.ndarray) or isinstance(array, tf.Tensor) for array in arrays)
175 + return tuple(
176 + None if array is None else np.squeeze(array, axis=0)
177 + for array in arrays
178 + )
179 +
180 + @staticmethod
181 + def get_first_match_word_from_top_predictions(special_words, original_name, top_predicted_words) -> Optional[Tuple[int, str]]:
182 + normalized_original_name = common.normalize_word(original_name)
183 + for suggestion_idx, predicted_word in enumerate(common.filter_impossible_names(special_words, top_predicted_words)):
184 + normalized_possible_suggestion = common.normalize_word(predicted_word)
185 + if normalized_original_name == normalized_possible_suggestion:
186 + return suggestion_idx, predicted_word
187 + return None
188 +
189 + @staticmethod
190 + def now_str():
191 + return datetime.now().strftime("%Y%m%d-%H%M%S: ")
192 +
193 + @staticmethod
194 + def chunks(l, n):
195 + """Yield successive n-sized chunks from l."""
196 + for i in range(0, len(l), n):
197 + yield l[i:i + n]
198 +
199 + @staticmethod
200 + def get_unique_list(lst: Iterable) -> list:
201 + return list(OrderedDict(((item, 0) for item in lst)).keys())
202 +
203 +
204 +class MethodPredictionResults:
205 + def __init__(self, original_name):
206 + self.original_name = original_name
207 + self.predictions = list()
208 + self.attention_paths = list()
209 +
210 + def append_prediction(self, name, probability):
211 + self.predictions.append({'name': name, 'probability': probability})
212 +
213 + def append_attention_path(self, attention_score, token1, path, token2):
214 + self.attention_paths.append({'score': attention_score,
215 + 'path': path,
216 + 'token1': token1,
217 + 'token2': token2})
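
A few of the helpers above, exercised for illustration; the expected outputs follow directly from the code:

from common import common

print(common.normalize_word('get|Name_2'))        # 'getname'   (non-letters stripped, then lowercased)
print(common.get_subtokens('get|file|name'))      # ['get', 'file', 'name']
print(list(common.chunks([1, 2, 3, 4, 5], 2)))    # [[1, 2], [3, 4], [5]]
print(common.get_unique_list([3, 1, 3, 2, 1]))    # [3, 1, 2]   (order-preserving de-duplication)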
1 +from math import ceil
2 +from typing import Optional
3 +import logging
4 +from argparse import ArgumentParser
5 +import sys
6 +import os
7 +
8 +
9 +class Config:
10 + @classmethod
11 + def arguments_parser(cls) -> ArgumentParser:
12 + parser = ArgumentParser()
13 + parser.add_argument("-d", "--data", dest="data_path",
14 + help="path to preprocessed dataset", required=False)
15 + parser.add_argument("-te", "--test", dest="test_path",
16 + help="path to test file", metavar="FILE", required=False, default='')
17 + parser.add_argument("-s", "--save", dest="save_path",
18 + help="path to save the model file", metavar="FILE", required=False)
19 + parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
20 + help="path to save the tokens embeddings file", metavar="FILE", required=False)
21 + parser.add_argument("-t2v", "--save_target2v", dest="save_t2v",
22 + help="path to save the targets embeddings file", metavar="FILE", required=False)
23 + parser.add_argument("-l", "--load", dest="load_path",
24 + help="path to load the model from", metavar="FILE", required=False)
25 + parser.add_argument('--save_w2v', dest='save_w2v', required=False,
26 + help="save word (token) vectors in word2vec format")
27 + parser.add_argument('--save_t2v', dest='save_t2v', required=False,
28 + help="save target vectors in word2vec format")
29 + parser.add_argument('--export_code_vectors', action='store_true', required=False,
30 + help="export code vectors for the given examples")
31 + parser.add_argument('--release', action='store_true',
32 + help='if specified and loading a trained model, release the loaded model for a lower model '
33 + 'size.')
34 + parser.add_argument('--predict', action='store_true',
35 + help='execute the interactive prediction shell')
36 + parser.add_argument("-fw", "--framework", dest="dl_framework", choices=['keras', 'tensorflow'],
37 + default='tensorflow', help="deep learning framework to use.")
38 + parser.add_argument("-v", "--verbose", dest="verbose_mode", type=int, required=False, default=1,
39 + help="verbose mode (should be in {0,1,2}).")
40 + parser.add_argument("-lp", "--logs-path", dest="logs_path", metavar="FILE", required=False,
41 + help="path to store logs into. if not given logs are not saved to file.")
42 + parser.add_argument('-tb', '--tensorboard', dest='use_tensorboard', action='store_true',
43 + help='use tensorboard during training')
44 + return parser
45 +
46 + def set_defaults(self):
47 + self.NUM_TRAIN_EPOCHS = 20
48 + self.SAVE_EVERY_EPOCHS = 1
49 + self.TRAIN_BATCH_SIZE = 1024
50 + self.TEST_BATCH_SIZE = self.TRAIN_BATCH_SIZE
51 + self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10
52 + self.NUM_BATCHES_TO_LOG_PROGRESS = 100
53 + self.NUM_TRAIN_BATCHES_TO_EVALUATE = 1800
54 + self.READER_NUM_PARALLEL_BATCHES = 6
55 + self.SHUFFLE_BUFFER_SIZE = 10000
56 + self.CSV_BUFFER_SIZE = 100 * 1024 * 1024
57 + self.MAX_TO_KEEP = 10
58 +
59 + self.MAX_CONTEXTS = 200
60 + self.MAX_TOKEN_VOCAB_SIZE = 1301136
61 + self.MAX_TARGET_VOCAB_SIZE = 261245
62 + self.MAX_PATH_VOCAB_SIZE = 911417
63 + self.DEFAULT_EMBEDDINGS_SIZE = 128
64 + self.TOKEN_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
65 + self.PATH_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
66 + self.CODE_VECTOR_SIZE = self.context_vector_size
67 + self.TARGET_EMBEDDINGS_SIZE = self.CODE_VECTOR_SIZE
68 + self.DROPOUT_KEEP_RATE = 0.75
69 + self.SEPARATE_OOV_AND_PAD = False
70 +
71 + def load_from_args(self):
72 + args = self.arguments_parser().parse_args()
73 + self.PREDICT = args.predict
74 + self.MODEL_SAVE_PATH = args.save_path
75 + self.MODEL_LOAD_PATH = args.load_path
76 + self.TRAIN_DATA_PATH_PREFIX = args.data_path
77 + self.TEST_DATA_PATH = args.test_path
78 + self.RELEASE = args.release
79 + self.EXPORT_CODE_VECTORS = args.export_code_vectors
80 + self.SAVE_W2V = args.save_w2v
81 + self.SAVE_T2V = args.save_t2v
82 + self.VERBOSE_MODE = args.verbose_mode
83 + self.LOGS_PATH = args.logs_path
84 + self.DL_FRAMEWORK = 'tensorflow' if not args.dl_framework else args.dl_framework
85 + self.USE_TENSORBOARD = args.use_tensorboard
86 +
87 + def __init__(self, set_defaults: bool = False, load_from_args: bool = False, verify: bool = False):
88 + self.NUM_TRAIN_EPOCHS: int = 0
89 + self.SAVE_EVERY_EPOCHS: int = 0
90 + self.TRAIN_BATCH_SIZE: int = 0
91 + self.TEST_BATCH_SIZE: int = 0
92 + self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION: int = 0
93 + self.NUM_BATCHES_TO_LOG_PROGRESS: int = 0
94 + self.NUM_TRAIN_BATCHES_TO_EVALUATE: int = 0
95 + self.READER_NUM_PARALLEL_BATCHES: int = 0
96 + self.SHUFFLE_BUFFER_SIZE: int = 0
97 + self.CSV_BUFFER_SIZE: int = 0
98 + self.MAX_TO_KEEP: int = 0
99 +
100 + self.MAX_CONTEXTS: int = 0
101 + self.MAX_TOKEN_VOCAB_SIZE: int = 0
102 + self.MAX_TARGET_VOCAB_SIZE: int = 0
103 + self.MAX_PATH_VOCAB_SIZE: int = 0
104 + self.DEFAULT_EMBEDDINGS_SIZE: int = 0
105 + self.TOKEN_EMBEDDINGS_SIZE: int = 0
106 + self.PATH_EMBEDDINGS_SIZE: int = 0
107 + self.CODE_VECTOR_SIZE: int = 0
108 + self.TARGET_EMBEDDINGS_SIZE: int = 0
109 + self.DROPOUT_KEEP_RATE: float = 0
110 + self.SEPARATE_OOV_AND_PAD: bool = False
111 +
112 + self.PREDICT: bool = False
113 + self.MODEL_SAVE_PATH: Optional[str] = None
114 + self.MODEL_LOAD_PATH: Optional[str] = None
115 + self.TRAIN_DATA_PATH_PREFIX: Optional[str] = None
116 + self.TEST_DATA_PATH: Optional[str] = ''
117 + self.RELEASE: bool = False
118 + self.EXPORT_CODE_VECTORS: bool = False
119 + self.SAVE_W2V: Optional[str] = None
120 + self.SAVE_T2V: Optional[str] = None
121 + self.VERBOSE_MODE: int = 0
122 + self.LOGS_PATH: Optional[str] = None
123 + self.DL_FRAMEWORK: str = 'tensorflow'
124 + self.USE_TENSORBOARD: bool = False
125 +
126 + self.NUM_TRAIN_EXAMPLES: int = 0
127 + self.NUM_TEST_EXAMPLES: int = 0
128 +
129 + self.__logger: Optional[logging.Logger] = None
130 +
131 + if set_defaults:
132 + self.set_defaults()
133 + if load_from_args:
134 + self.load_from_args()
135 + if verify:
136 + self.verify()
137 +
138 + @property
139 + def context_vector_size(self) -> int:
140 + return self.PATH_EMBEDDINGS_SIZE + 2 * self.TOKEN_EMBEDDINGS_SIZE
141 +
142 + @property
143 + def is_training(self) -> bool:
144 + return bool(self.TRAIN_DATA_PATH_PREFIX)
145 +
146 + @property
147 + def is_loading(self) -> bool:
148 + return bool(self.MODEL_LOAD_PATH)
149 +
150 + @property
151 + def is_saving(self) -> bool:
152 + return bool(self.MODEL_SAVE_PATH)
153 +
154 + @property
155 + def is_testing(self) -> bool:
156 + return bool(self.TEST_DATA_PATH)
157 +
158 + @property
159 + def train_steps_per_epoch(self) -> int:
160 + return ceil(self.NUM_TRAIN_EXAMPLES / self.TRAIN_BATCH_SIZE) if self.TRAIN_BATCH_SIZE else 0
161 +
162 + @property
163 + def test_steps(self) -> int:
164 + return ceil(self.NUM_TEST_EXAMPLES / self.TEST_BATCH_SIZE) if self.TEST_BATCH_SIZE else 0
165 +
166 + def data_path(self, is_evaluating: bool = False):
167 + return self.TEST_DATA_PATH if is_evaluating else self.train_data_path
168 +
169 + def batch_size(self, is_evaluating: bool = False):
170 + return self.TEST_BATCH_SIZE if is_evaluating else self.TRAIN_BATCH_SIZE # take min with NUM_TRAIN_EXAMPLES?
171 +
172 + @property
173 + def train_data_path(self) -> Optional[str]:
174 + if not self.is_training:
175 + return None
176 + return '{}.train.c2v'.format(self.TRAIN_DATA_PATH_PREFIX)
177 +
178 + @property
179 + def word_freq_dict_path(self) -> Optional[str]:
180 + if not self.is_training:
181 + return None
182 + return '{}.dict.c2v'.format(self.TRAIN_DATA_PATH_PREFIX)
183 +
184 + @classmethod
185 + def get_vocabularies_path_from_model_path(cls, model_file_path: str) -> str:
186 + vocabularies_save_file_name = "dictionaries.bin"
187 + return '/'.join(model_file_path.split('/')[:-1] + [vocabularies_save_file_name])
188 +
189 + @classmethod
190 + def get_entire_model_path(cls, model_path: str) -> str:
191 + return model_path + '__entire-model'
192 +
193 + @classmethod
194 + def get_model_weights_path(cls, model_path: str) -> str:
195 + return model_path + '__only-weights'
196 +
197 + @property
198 + def model_load_dir(self):
199 + return '/'.join(self.MODEL_LOAD_PATH.split('/')[:-1])
200 +
201 + @property
202 + def entire_model_load_path(self) -> Optional[str]:
203 + if not self.is_loading:
204 + return None
205 + return self.get_entire_model_path(self.MODEL_LOAD_PATH)
206 +
207 + @property
208 + def model_weights_load_path(self) -> Optional[str]:
209 + if not self.is_loading:
210 + return None
211 + return self.get_model_weights_path(self.MODEL_LOAD_PATH)
212 +
213 + @property
214 + def entire_model_save_path(self) -> Optional[str]:
215 + if not self.is_saving:
216 + return None
217 + return self.get_entire_model_path(self.MODEL_SAVE_PATH)
218 +
219 + @property
220 + def model_weights_save_path(self) -> Optional[str]:
221 + if not self.is_saving:
222 + return None
223 + return self.get_model_weights_path(self.MODEL_SAVE_PATH)
224 +
225 + def verify(self):
226 + if not self.is_training and not self.is_loading:
227 + raise ValueError("Must train or load a model.")
228 + if self.is_loading and not os.path.isdir(self.model_load_dir):
229 + raise ValueError("Model load dir `{model_load_dir}` does not exist.".format(
230 + model_load_dir=self.model_load_dir))
231 +
232 + def __iter__(self):
233 + for attr_name in dir(self):
234 + if attr_name.startswith("__"):
235 + continue
236 + try:
237 + attr_value = getattr(self, attr_name, None)
238 + except:
239 + attr_value = None
240 + if callable(attr_value):
241 + continue
242 + yield attr_name, attr_value
243 +
244 + def get_logger(self) -> logging.Logger:
245 + if self.__logger is None:
246 + self.__logger = logging.getLogger('code2vec')
247 + self.__logger.setLevel(logging.INFO)
248 + self.__logger.handlers = []
249 + self.__logger.propagate = 0
250 +
251 + formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
252 +
253 + if self.VERBOSE_MODE >= 1:
254 + ch = logging.StreamHandler(sys.stdout)
255 + ch.setLevel(logging.INFO)
256 + ch.setFormatter(formatter)
257 + self.__logger.addHandler(ch)
258 +
259 + if self.LOGS_PATH:
260 + fh = logging.FileHandler(self.LOGS_PATH)
261 + fh.setLevel(logging.INFO)
262 + fh.setFormatter(formatter)
263 + self.__logger.addHandler(fh)
264 +
265 + return self.__logger
266 +
267 + def log(self, msg):
268 + self.get_logger().info(msg)
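
A short sketch of how Config is typically consumed outside of the command line; the data path is hypothetical, and the printed values follow from set_defaults() and the derived properties above:

from config import Config

config = Config(set_defaults=True)
config.TRAIN_DATA_PATH_PREFIX = 'data/dataset/dataset'
config.VERBOSE_MODE = 1                  # route log() output to stdout

print(config.context_vector_size)        # 384 = PATH_EMBEDDINGS_SIZE + 2 * TOKEN_EMBEDDINGS_SIZE = 128 + 2 * 128
print(config.CODE_VECTOR_SIZE)           # 384, pinned to context_vector_size by set_defaults()
print(config.train_data_path)            # 'data/dataset/dataset.train.c2v'
print(config.word_freq_dict_path)        # 'data/dataset/dataset.dict.c2v'
config.log('configuration ready')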
1 +import traceback
2 +
3 +from common import common
4 +from py_extractor import PyExtractor
5 +
6 +SHOW_TOP_CONTEXTS = 10
7 +MAX_PATH_LENGTH = 8
8 +MAX_PATH_WIDTH = 2
9 +input_filename = 'test.c2v'
10 +
11 +
12 +class InteractivePredictor:
13 + exit_keywords = ['exit', 'quit', 'q']
14 +
15 + def __init__(self, config, model):
16 + model.predict([])
17 + self.model = model
18 + self.config = config
19 + self.path_extractor = PyExtractor(config)
20 +
21 + def predict(self):
22 + print('Starting interactive prediction...')
23 + while True:
24 + print('Modify the file: "%s" and press any key when ready, or "q" / "quit" / "exit" to exit' % input_filename)
25 + user_input = input()
26 + if user_input.lower() in self.exit_keywords:
27 + print('Exiting...')
28 + return
29 + try:
30 + predict_lines, hash_to_string_dict = self.path_extractor.extract_paths(input_filename)
31 + except ValueError as e:
32 + print(e)
33 + continue
34 + raw_prediction_results = self.model.predict(predict_lines)
35 + method_prediction_results = common.parse_prediction_results(
36 + raw_prediction_results, hash_to_string_dict,
37 + self.model.vocabs.target_vocab.special_words, topk=SHOW_TOP_CONTEXTS)
38 + for raw_prediction, method_prediction in zip(raw_prediction_results, method_prediction_results):
39 + print('Original name:\t' + method_prediction.original_name)
40 + for name_prob_pair in method_prediction.predictions:
41 + print('\t(%f) predicted: %s' % (name_prob_pair['probability'], name_prob_pair['name']))
42 + print('Attention:')
43 + for attention_obj in method_prediction.attention_paths:
44 + print('%f\tcontext: %s,%s,%s' % (
45 + attention_obj['score'], attention_obj['token1'], attention_obj['path'], attention_obj['token2']))
46 + if self.config.EXPORT_CODE_VECTORS:
47 + print('Code vector:')
48 + print(' '.join(map(str, raw_prediction.code_vector)))
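
Note that this predictor does not run an AST extractor itself: PyExtractor (listed below) simply reads 'test.c2v', so that file must already contain one method per line in the extracted path-context format. A schematic line with hypothetical tokens and path ids:

my|method token1,1345,token2 token2,2210,token3

Each space-separated field after the method name is a 'source,path,target' triple; only the first MAX_CONTEXTS triples per line are used.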
1 +import numpy as np
2 +import abc
3 +import os
4 +from typing import NamedTuple, Optional, List, Dict, Tuple, Iterable
5 +
6 +from common import common
7 +from vocabularies import Code2VecVocabs, VocabType
8 +from config import Config
9 +
10 +
11 +class ModelEvaluationResults(NamedTuple):
12 + topk_acc: float
13 + subtoken_precision: float
14 + subtoken_recall: float
15 + subtoken_f1: float
16 + loss: Optional[float] = None
17 +
18 + def __str__(self):
19 + res_str = 'topk_acc: {topk_acc}, precision: {precision}, recall: {recall}, F1: {f1}'.format(
20 + topk_acc=self.topk_acc,
21 + precision=self.subtoken_precision,
22 + recall=self.subtoken_recall,
23 + f1=self.subtoken_f1)
24 + if self.loss is not None:
25 + res_str = ('loss: {}, '.format(self.loss)) + res_str
26 + return res_str
27 +
28 +
29 +class ModelPredictionResults(NamedTuple):
30 + original_name: str
31 + topk_predicted_words: np.ndarray
32 + topk_predicted_words_scores: np.ndarray
33 + attention_per_context: Dict[Tuple[str, str, str], float]
34 + code_vector: Optional[np.ndarray] = None
35 +
36 +
37 +class Code2VecModelBase(abc.ABC):
38 + def __init__(self, config: Config):
39 + self.config = config
40 + self.config.verify()
41 +
42 + self._log_creating_model()
43 +
44 + if not config.RELEASE:
45 + self._init_num_of_examples()
46 + self._log_model_configuration()
47 + self.vocabs = Code2VecVocabs(config)
48 + self.vocabs.target_vocab.get_index_to_word_lookup_table()
49 + self._load_or_create_inner_model()
50 + self._initialize()
51 +
52 + def _log_creating_model(self):
53 + self.log('')
54 + self.log('')
55 + self.log('---------------------------------------------------------------------')
56 + self.log('---------------------------------------------------------------------')
57 + self.log('---------------------- Creating code2vec model ----------------------')
58 + self.log('---------------------------------------------------------------------')
59 + self.log('---------------------------------------------------------------------')
60 +
61 + def _log_model_configuration(self):
62 + self.log('---------------------------------------------------------------------')
63 + self.log('----------------- Configuration - Hyper Parameters ------------------')
64 + longest_param_name_len = max(len(param_name) for param_name, _ in self.config)
65 + for param_name, param_val in self.config:
66 + self.log('{name: <{name_len}}{val}'.format(
67 + name=param_name, val=param_val, name_len=longest_param_name_len+2))
68 + self.log('---------------------------------------------------------------------')
69 +
70 + @property
71 + def logger(self):
72 + return self.config.get_logger()
73 +
74 + def log(self, msg):
75 + self.logger.info(msg)
76 +
77 + def _init_num_of_examples(self):
78 + self.log('Checking number of examples ...')
79 + if self.config.is_training:
80 + self.config.NUM_TRAIN_EXAMPLES = self._get_num_of_examples_for_dataset(self.config.train_data_path)
81 + self.log(' Number of train examples: {}'.format(self.config.NUM_TRAIN_EXAMPLES))
82 + if self.config.is_testing:
83 + self.config.NUM_TEST_EXAMPLES = self._get_num_of_examples_for_dataset(self.config.TEST_DATA_PATH)
84 + self.log(' Number of test examples: {}'.format(self.config.NUM_TEST_EXAMPLES))
85 +
86 + @staticmethod
87 + def _get_num_of_examples_for_dataset(dataset_path: str) -> int:
88 + dataset_num_examples_file_path = dataset_path + '.num_examples'
89 + if os.path.isfile(dataset_num_examples_file_path):
90 + with open(dataset_num_examples_file_path, 'r') as file:
91 + num_examples_in_dataset = int(file.readline())
92 + else:
93 + num_examples_in_dataset = common.count_lines_in_file(dataset_path)
94 + with open(dataset_num_examples_file_path, 'w') as file:
95 + file.write(str(num_examples_in_dataset))
96 + return num_examples_in_dataset
97 +
98 + def load_or_build(self):
99 + self.vocabs = Code2VecVocabs(self.config)
100 + self._load_or_create_inner_model()
101 +
102 + def save(self, model_save_path=None):
103 + if model_save_path is None:
104 + model_save_path = self.config.MODEL_SAVE_PATH
105 + model_save_dir = '/'.join(model_save_path.split('/')[:-1])
106 + if not os.path.isdir(model_save_dir):
107 + os.makedirs(model_save_dir, exist_ok=True)
108 + self.vocabs.save(self.config.get_vocabularies_path_from_model_path(model_save_path))
109 + self._save_inner_model(model_save_path)
110 +
111 + def _write_code_vectors(self, file, code_vectors):
112 + for vec in code_vectors:
113 + file.write(' '.join(map(str, vec)) + '\n')
114 +
115 + def _get_attention_weight_per_context(
116 + self, path_source_strings: Iterable[str], path_strings: Iterable[str], path_target_strings: Iterable[str],
117 + attention_weights: Iterable[float]) -> Dict[Tuple[str, str, str], float]:
118 + attention_weights = np.squeeze(attention_weights, axis=-1) # (max_contexts, )
119 + attention_per_context: Dict[Tuple[str, str, str], float] = {}
120 +
121 + for path_source, path, path_target, weight in \
122 + zip(path_source_strings, path_strings, path_target_strings, attention_weights):
123 + string_context_triplet = (common.binary_to_string(path_source),
124 + common.binary_to_string(path),
125 + common.binary_to_string(path_target))
126 + attention_per_context[string_context_triplet] = weight
127 + return attention_per_context
128 +
129 + def close_session(self):
130 + pass
131 +
132 + @abc.abstractmethod
133 + def train(self):
134 + ...
135 +
136 + @abc.abstractmethod
137 + def evaluate(self) -> Optional[ModelEvaluationResults]:
138 + ...
139 +
140 + @abc.abstractmethod
141 + def predict(self, predict_data_lines: Iterable[str]) -> List[ModelPredictionResults]:
142 + ...
143 +
144 + @abc.abstractmethod
145 + def _save_inner_model(self, path):
146 + ...
147 +
148 + def _load_or_create_inner_model(self):
149 + if self.config.is_loading:
150 + self._load_inner_model()
151 + else:
152 + self._create_inner_model()
153 +
154 + @abc.abstractmethod
155 + def _load_inner_model(self):
156 + ...
157 +
158 + def _create_inner_model(self):
159 + pass
160 +
161 + def _initialize(self):
162 + pass
163 +
164 + @abc.abstractmethod
165 + def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
166 + ...
167 +
168 + def save_word2vec_format(self, dest_save_path: str, vocab_type: VocabType):
169 + if vocab_type not in VocabType:
170 + raise ValueError('`vocab_type` should be `VocabType.Token`, `VocabType.Target` or `VocabType.Path`.')
171 + vocab_embedding_matrix = self._get_vocab_embedding_as_np_array(vocab_type)
172 + index_to_word = self.vocabs.get(vocab_type).index_to_word
173 + with open(dest_save_path, 'w') as words_file:
174 + common.save_word2vec_file(words_file, index_to_word, vocab_embedding_matrix)
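
For illustration only (the numbers are made up), the string produced by ModelEvaluationResults.__str__, which is what gets logged after evaluation:

from model_base import ModelEvaluationResults

res = ModelEvaluationResults(topk_acc=0.42, subtoken_precision=0.51,
                             subtoken_recall=0.44, subtoken_f1=0.47, loss=3.1)
print(res)   # loss: 3.1, topk_acc: 0.42, precision: 0.51, recall: 0.44, F1: 0.47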
1 +import tensorflow as tf
2 +from typing import Dict, Tuple, NamedTuple, Union, Optional, Iterable
3 +from config import Config
4 +from vocabularies import Code2VecVocabs
5 +import abc
6 +from functools import reduce
7 +from enum import Enum
8 +
9 +
10 +class EstimatorAction(Enum):
11 + Train = 'train'
12 + Evaluate = 'evaluate'
13 + Predict = 'predict'
14 +
15 + @property
16 + def is_train(self):
17 + return self is EstimatorAction.Train
18 +
19 + @property
20 + def is_evaluate(self):
21 + return self is EstimatorAction.Evaluate
22 +
23 + @property
24 + def is_predict(self):
25 + return self is EstimatorAction.Predict
26 +
27 + @property
28 + def is_evaluate_or_predict(self):
29 + return self.is_evaluate or self.is_predict
30 +
31 +
32 +class ReaderInputTensors(NamedTuple):
33 + path_source_token_indices: tf.Tensor
34 + path_indices: tf.Tensor
35 + path_target_token_indices: tf.Tensor
36 + context_valid_mask: tf.Tensor
37 + target_index: Optional[tf.Tensor] = None
38 + target_string: Optional[tf.Tensor] = None
39 + path_source_token_strings: Optional[tf.Tensor] = None
40 + path_strings: Optional[tf.Tensor] = None
41 + path_target_token_strings: Optional[tf.Tensor] = None
42 +
43 +
44 +class ModelInputTensorsFormer(abc.ABC):
45 + @abc.abstractmethod
46 + def to_model_input_form(self, input_tensors: ReaderInputTensors):
47 + ...
48 +
49 + @abc.abstractmethod
50 + def from_model_input_form(self, input_row) -> ReaderInputTensors:
51 + ...
52 +
53 +
54 +class PathContextReader:
55 + def __init__(self,
56 + vocabs: Code2VecVocabs,
57 + config: Config,
58 + model_input_tensors_former: ModelInputTensorsFormer,
59 + estimator_action: EstimatorAction,
60 + repeat_endlessly: bool = False):
61 + self.vocabs = vocabs
62 + self.config = config
63 + self.model_input_tensors_former = model_input_tensors_former
64 + self.estimator_action = estimator_action
65 + self.repeat_endlessly = repeat_endlessly
66 + self.CONTEXT_PADDING = ','.join([self.vocabs.token_vocab.special_words.PAD,
67 + self.vocabs.path_vocab.special_words.PAD,
68 + self.vocabs.token_vocab.special_words.PAD])
69 + self.csv_record_defaults = [[self.vocabs.target_vocab.special_words.OOV]] + \
70 + ([[self.CONTEXT_PADDING]] * self.config.MAX_CONTEXTS)
71 +
72 + self.create_needed_vocabs_lookup_tables(self.vocabs)
73 +
74 + self._dataset: Optional[tf.data.Dataset] = None
75 +
76 + @classmethod
77 + def create_needed_vocabs_lookup_tables(cls, vocabs: Code2VecVocabs):
78 + vocabs.token_vocab.get_word_to_index_lookup_table()
79 + vocabs.path_vocab.get_word_to_index_lookup_table()
80 + vocabs.target_vocab.get_word_to_index_lookup_table()
81 +
82 + @tf.function
83 + def process_input_row(self, row_placeholder):
84 + parts = tf.io.decode_csv(
85 + row_placeholder, record_defaults=self.csv_record_defaults, field_delim=' ', use_quote_delim=False)
86 + tensors = self._map_raw_dataset_row_to_input_tensors(*parts)
87 +
88 + tensors_expanded = ReaderInputTensors(
89 + **{name: None if tensor is None else tf.expand_dims(tensor, axis=0)
90 + for name, tensor in tensors._asdict().items()})
91 + return self.model_input_tensors_former.to_model_input_form(tensors_expanded)
92 +
93 + def process_and_iterate_input_from_data_lines(self, input_data_lines: Iterable) -> Iterable:
94 + for data_row in input_data_lines:
95 + processed_row = self.process_input_row(data_row)
96 + yield processed_row
97 +
98 + def get_dataset(self, input_data_rows: Optional = None) -> tf.data.Dataset:
99 + if self._dataset is None:
100 + self._dataset = self._create_dataset_pipeline(input_data_rows)
101 + return self._dataset
102 +
103 + def _create_dataset_pipeline(self, input_data_rows: Optional = None) -> tf.data.Dataset:
104 + if input_data_rows is None:
105 + assert not self.estimator_action.is_predict
106 + dataset = tf.data.experimental.CsvDataset(
107 + self.config.data_path(is_evaluating=self.estimator_action.is_evaluate),
108 + record_defaults=self.csv_record_defaults, field_delim=' ', use_quote_delim=False,
109 + buffer_size=self.config.CSV_BUFFER_SIZE)
110 + else:
111 + dataset = tf.data.Dataset.from_tensor_slices(input_data_rows)
112 + dataset = dataset.map(
113 + lambda input_line: tf.io.decode_csv(
114 + tf.reshape(tf.cast(input_line, tf.string), ()),
115 + record_defaults=self.csv_record_defaults,
116 + field_delim=' ', use_quote_delim=False))
117 +
118 + if self.repeat_endlessly:
119 + dataset = dataset.repeat()
120 + if self.estimator_action.is_train:
121 + if not self.repeat_endlessly and self.config.NUM_TRAIN_EPOCHS > 1:
122 + dataset = dataset.repeat(self.config.NUM_TRAIN_EPOCHS)
123 + dataset = dataset.shuffle(self.config.SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True)
124 +
125 + dataset = dataset.map(self._map_raw_dataset_row_to_expected_model_input_form,
126 + num_parallel_calls=self.config.READER_NUM_PARALLEL_BATCHES)
127 + batch_size = self.config.batch_size(is_evaluating=self.estimator_action.is_evaluate)
128 + if self.estimator_action.is_predict:
129 + dataset = dataset.batch(1)
130 + else:
131 + dataset = dataset.filter(self._filter_input_rows)
132 + dataset = dataset.batch(batch_size)
133 +
134 + dataset = dataset.prefetch(buffer_size=40)
135 + return dataset
136 +
137 + def _filter_input_rows(self, *row_parts) -> tf.bool:
138 + row_parts = self.model_input_tensors_former.from_model_input_form(row_parts)
139 +
140 + any_word_valid_mask_per_context_part = [
141 + tf.not_equal(tf.reduce_max(row_parts.path_source_token_indices, axis=0),
142 + self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]),
143 + tf.not_equal(tf.reduce_max(row_parts.path_target_token_indices, axis=0),
144 + self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]),
145 + tf.not_equal(tf.reduce_max(row_parts.path_indices, axis=0),
146 + self.vocabs.path_vocab.word_to_index[self.vocabs.path_vocab.special_words.PAD])]
147 + any_contexts_is_valid = reduce(tf.logical_or, any_word_valid_mask_per_context_part)
148 +
149 + if self.estimator_action.is_evaluate:
150 + cond = any_contexts_is_valid
151 + else:
152 + word_is_valid = tf.greater(
153 + row_parts.target_index, self.vocabs.target_vocab.word_to_index[self.vocabs.target_vocab.special_words.OOV]) # scalar
154 + cond = tf.logical_and(word_is_valid, any_contexts_is_valid)
155 +
156 + return cond
157 +
158 + def _map_raw_dataset_row_to_expected_model_input_form(self, *row_parts) -> \
159 + Tuple[Union[tf.Tensor, Tuple[tf.Tensor, ...], Dict[str, tf.Tensor]], ...]:
160 + tensors = self._map_raw_dataset_row_to_input_tensors(*row_parts)
161 + return self.model_input_tensors_former.to_model_input_form(tensors)
162 +
163 + def _map_raw_dataset_row_to_input_tensors(self, *row_parts) -> ReaderInputTensors:
164 + row_parts = list(row_parts)
165 + target_str = row_parts[0]
166 + target_index = self.vocabs.target_vocab.lookup_index(target_str)
167 +
168 + contexts_str = tf.stack(row_parts[1:(self.config.MAX_CONTEXTS + 1)], axis=0)
169 + split_contexts = tf.compat.v1.string_split(contexts_str, sep=',', skip_empty=False)
170 + sparse_split_contexts = tf.sparse.SparseTensor(
171 + indices=split_contexts.indices, values=split_contexts.values, dense_shape=[self.config.MAX_CONTEXTS, 3])
172 + dense_split_contexts = tf.reshape(
173 + tf.sparse.to_dense(sp_input=sparse_split_contexts, default_value=self.vocabs.token_vocab.special_words.PAD),
174 + shape=[self.config.MAX_CONTEXTS, 3])
175 +
176 + path_source_token_strings = tf.squeeze(
177 + tf.slice(dense_split_contexts, begin=[0, 0], size=[self.config.MAX_CONTEXTS, 1]), axis=1)
178 + path_strings = tf.squeeze(
179 + tf.slice(dense_split_contexts, begin=[0, 1], size=[self.config.MAX_CONTEXTS, 1]), axis=1)
180 + path_target_token_strings = tf.squeeze(
181 + tf.slice(dense_split_contexts, begin=[0, 2], size=[self.config.MAX_CONTEXTS, 1]), axis=1)
182 +
183 + path_source_token_indices = self.vocabs.token_vocab.lookup_index(path_source_token_strings)
184 + path_indices = self.vocabs.path_vocab.lookup_index(path_strings)
185 + path_target_token_indices = self.vocabs.token_vocab.lookup_index(path_target_token_strings)
186 +
187 + valid_word_mask_per_context_part = [
188 + tf.not_equal(path_source_token_indices, self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]),
189 + tf.not_equal(path_target_token_indices, self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]),
190 + tf.not_equal(path_indices, self.vocabs.path_vocab.word_to_index[self.vocabs.path_vocab.special_words.PAD])]
191 + context_valid_mask = tf.cast(reduce(tf.logical_or, valid_word_mask_per_context_part), dtype=tf.float32)
192 +
193 + return ReaderInputTensors(
194 + path_source_token_indices=path_source_token_indices,
195 + path_indices=path_indices,
196 + path_target_token_indices=path_target_token_indices,
197 + context_valid_mask=context_valid_mask,
198 + target_index=target_index,
199 + target_string=target_str,
200 + path_source_token_strings=path_source_token_strings,
201 + path_strings=path_strings,
202 + path_target_token_strings=path_target_token_strings
203 + )
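
Each record the reader consumes is a single space-separated line: the target label followed by up to MAX_CONTEXTS comma-separated 'source,path,target' triples, with trailing fields left empty so that decode_csv falls back to the PAD defaults. A schematic row with hypothetical values:

get|name name,1913,file file,2022,stream

The reader splits each triple on ',', looks the three parts up in the token, path, and target vocabularies, and derives context_valid_mask from the PAD comparisons above.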
1 +import random
2 +from argparse import ArgumentParser
3 +import common
4 +import pickle
5 +
6 +def save_dictionaries(dataset_name, word_to_count, path_to_count, target_to_count,
7 + num_training_examples):
8 + save_dict_file_path = '{}.dict.c2v'.format(dataset_name)
9 + with open(save_dict_file_path, 'wb') as file:
10 + pickle.dump(word_to_count, file)
11 + pickle.dump(path_to_count, file)
12 + pickle.dump(target_to_count, file)
13 + pickle.dump(num_training_examples, file)
14 + print('Dictionaries saved to: {}'.format(save_dict_file_path))
15 +
16 +
17 +def process_file(file_path, data_file_role, dataset_name, word_to_count, path_to_count, max_contexts):
18 + sum_total = 0
19 + sum_sampled = 0
20 + total = 0
21 + empty = 0
22 + max_unfiltered = 0
23 + output_path = '{}.{}.c2v'.format(dataset_name, data_file_role)
24 + with open(output_path, 'w') as outfile:
25 + with open(file_path, 'r') as file:
26 + for line in file:
27 + parts = line.rstrip('\n').split(' ')
28 + target_name = parts[0]
29 + contexts = parts[1:]
30 +
31 + if len(contexts) > max_unfiltered:
32 + max_unfiltered = len(contexts)
33 + sum_total += len(contexts)
34 +
35 + if len(contexts) > max_contexts:
36 + context_parts = [c.split(',') for c in contexts]
37 + full_found_contexts = [c for i, c in enumerate(contexts)
38 + if context_full_found(context_parts[i], word_to_count, path_to_count)]
39 + partial_found_contexts = [c for i, c in enumerate(contexts)
40 + if context_partial_found(context_parts[i], word_to_count, path_to_count)
41 + and not context_full_found(context_parts[i], word_to_count,
42 + path_to_count)]
43 + if len(full_found_contexts) > max_contexts:
44 + contexts = random.sample(full_found_contexts, max_contexts)
45 + elif len(full_found_contexts) <= max_contexts \
46 + and len(full_found_contexts) + len(partial_found_contexts) > max_contexts:
47 + contexts = full_found_contexts + \
48 + random.sample(partial_found_contexts, max_contexts - len(full_found_contexts))
49 + else:
50 + contexts = full_found_contexts + partial_found_contexts
51 +
52 + if len(contexts) == 0:
53 + empty += 1
54 + continue
55 +
56 + sum_sampled += len(contexts)
57 +
58 + csv_padding = " " * (max_contexts - len(contexts))
59 + outfile.write(target_name + ' ' + " ".join(contexts) + csv_padding + '\n')
60 + total += 1
61 +
62 + print('File: ' + file_path)
63 + print('Average total contexts: ' + str(float(sum_total) / total))
64 + print('Average final (after sampling) contexts: ' + str(float(sum_sampled) / total))
65 + print('Total examples: ' + str(total))
66 + print('Empty examples: ' + str(empty))
67 + print('Max number of contexts per word: ' + str(max_unfiltered))
68 + return total
69 +
70 +
71 +def context_full_found(context_parts, word_to_count, path_to_count):
72 + return context_parts[0] in word_to_count \
73 + and context_parts[1] in path_to_count and context_parts[2] in word_to_count
74 +
75 +
76 +def context_partial_found(context_parts, word_to_count, path_to_count):
77 + return context_parts[0] in word_to_count \
78 + or context_parts[1] in path_to_count or context_parts[2] in word_to_count
79 +
80 +
81 +if __name__ == '__main__':
82 +
83 + parser = ArgumentParser()
84 + parser.add_argument("-trd", "--train_data", dest="train_data_path",
85 + help="path to training data file", required=True)
86 + parser.add_argument("-ted", "--test_data", dest="test_data_path",
87 + help="path to test data file", required=True)
88 + parser.add_argument("-vd", "--val_data", dest="val_data_path",
89 + help="path to validation data file", required=True)
90 + parser.add_argument("-mc", "--max_contexts", dest="max_contexts", default=200,
91 + help="number of max contexts to keep", required=False)
92 + parser.add_argument("-wvs", "--word_vocab_size", dest="word_vocab_size", default=1301136,
93 + help="Max number of origin word in to keep in the vocabulary", required=False)
94 + parser.add_argument("-pvs", "--path_vocab_size", dest="path_vocab_size", default=911417,
95 + help="Max number of paths to keep in the vocabulary", required=False)
96 + parser.add_argument("-tvs", "--target_vocab_size", dest="target_vocab_size", default=261245,
97 + help="Max number of target words to keep in the vocabulary", required=False)
98 + parser.add_argument("-wh", "--word_histogram", dest="word_histogram",
99 + help="word histogram file", metavar="FILE", required=True)
100 + parser.add_argument("-ph", "--path_histogram", dest="path_histogram",
101 + help="path_histogram file", metavar="FILE", required=True)
102 + parser.add_argument("-th", "--target_histogram", dest="target_histogram",
103 + help="target histogram file", metavar="FILE", required=True)
104 + parser.add_argument("-o", "--output_name", dest="output_name",
105 + help="output name - the base name for the created dataset", metavar="FILE", required=True,
106 + default='data')
107 + args = parser.parse_args()
108 +
109 + train_data_path = args.train_data_path
110 + test_data_path = args.test_data_path
111 + val_data_path = args.val_data_path
112 + word_histogram_path = args.word_histogram
113 + path_histogram_path = args.path_histogram
114 +
115 + word_histogram_data = common.common.load_vocab_from_histogram(word_histogram_path, start_from=1,
116 + max_size=int(args.word_vocab_size),
117 + return_counts=True)
118 + _, _, _, word_to_count = word_histogram_data
119 + _, _, _, path_to_count = common.common.load_vocab_from_histogram(path_histogram_path, start_from=1,
120 + max_size=int(args.path_vocab_size),
121 + return_counts=True)
122 + _, _, _, target_to_count = common.common.load_vocab_from_histogram(args.target_histogram, start_from=1,
123 + max_size=int(args.target_vocab_size),
124 + return_counts=True)
125 +
126 + num_training_examples = 0
127 + for data_file_path, data_role in zip([test_data_path, val_data_path, train_data_path], ['test', 'val', 'train']):
128 + num_examples = process_file(file_path=data_file_path, data_file_role=data_role, dataset_name=args.output_name,
129 + word_to_count=word_to_count, path_to_count=path_to_count,
130 + max_contexts=int(args.max_contexts))
131 + if data_role == 'train':
132 + num_training_examples = num_examples
133 +
134 + save_dictionaries(dataset_name=args.output_name, word_to_count=word_to_count,
135 + path_to_count=path_to_count, target_to_count=target_to_count,
136 + num_training_examples=num_training_examples)
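
For illustration, how the two membership checks above classify a context, assuming the two functions defined in this file are in scope; the vocabularies here are hypothetical stand-ins for the histogram-derived counts:

word_to_count = {'x': 3, 'y': 2}
path_to_count = {'1713': 5}

print(context_full_found(['x', '1713', 'y'], word_to_count, path_to_count))     # True: both tokens and the path are in-vocabulary
print(context_partial_found(['x', '9999', 'z'], word_to_count, path_to_count))  # True: at least one element is in-vocabulary
print(context_full_found(['x', '9999', 'z'], word_to_count, path_to_count))     # False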
1 +TRAIN_DIR=dataset_train
2 +VAL_DIR=dataset_val
3 +TEST_DIR=dataset_test
4 +DATASET_NAME=dataset
5 +MAX_CONTEXTS=200
6 +WORD_VOCAB_SIZE=1301136
7 +PATH_VOCAB_SIZE=911417
8 +TARGET_VOCAB_SIZE=261245
9 +NUM_THREADS=64
10 +PYTHON=python
11 +###########################################################
12 +
13 +TRAIN_DATA_PATH=data/path_contexts_train.csv
14 +VAL_DATA_PATH=data/path_contexts_val.csv
15 +TEST_DATA_PATH=data/path_contexts_test.csv
16 +
17 +TRAIN_DATA_FILE=${TRAIN_DATA_PATH}
18 +VAL_DATA_FILE=${VAL_DATA_PATH}
19 +TEST_DATA_FILE=${TEST_DATA_PATH}
20 +
21 +mkdir -p data
22 +mkdir -p data/${DATASET_NAME}
23 +
24 +TARGET_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.tgt.c2v
25 +ORIGIN_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.ori.c2v
26 +PATH_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.path.c2v
27 +
28 +cat ${TRAIN_DATA_FILE} | cut -d' ' -f1 | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${TARGET_HISTOGRAM_FILE}
29 +cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${ORIGIN_HISTOGRAM_FILE}
30 +cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${PATH_HISTOGRAM_FILE}
31 +
32 +DIR=`dirname "$0"`
33 +
34 +${PYTHON} ${DIR}/preprocess.py --train_data ${TRAIN_DATA_FILE} --test_data ${TEST_DATA_FILE} --val_data ${VAL_DATA_FILE} \
35 + --max_contexts ${MAX_CONTEXTS} --word_vocab_size ${WORD_VOCAB_SIZE} --path_vocab_size ${PATH_VOCAB_SIZE} \
36 + --target_vocab_size ${TARGET_VOCAB_SIZE} --word_histogram ${ORIGIN_HISTOGRAM_FILE} \
37 + --path_histogram ${PATH_HISTOGRAM_FILE} --target_histogram ${TARGET_HISTOGRAM_FILE} --output_name data/${DATASET_NAME}/${DATASET_NAME}
38 +
39 +rm ${TARGET_HISTOGRAM_FILE} ${ORIGIN_HISTOGRAM_FILE} ${PATH_HISTOGRAM_FILE}
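
The three awk pipelines above produce plain 'value count' histograms, one entry per line, which is exactly the format _load_vocab_from_histogram parses. A hypothetical excerpt of the origin-token histogram:

get 10421
set 9876
name 5321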
1 +import subprocess
2 +
3 +class PyExtractor:
4 + def __init__(self, config):
5 + self.config = config
6 +
7 + def read_file(self, input_filename):
8 + with open(input_filename, 'r') as file:
9 + return file.readlines()
10 +
11 + def extract_paths(self, path):
12 + output = self.read_file(path)
13 +
14 + if len(output) == 0:
15 + # no contexts were read from the file; report it (the original referenced an undefined `err` here)
16 + raise ValueError('No path contexts could be read from file: %s' % path)
17 + hash_to_string_dict = {}
18 + result = []
19 + for i, line in enumerate(output):
20 + parts = line.rstrip().split(' ')
21 + method_name = parts[0]
22 + current_result_line_parts = [method_name]
23 + contexts = parts[1:]
24 + for context in contexts[:self.config.MAX_CONTEXTS]:
25 + context_parts = context.split(',')
26 + context_word1 = context_parts[0]
27 + context_path = context_parts[1]
28 + context_word2 = context_parts[2]
29 + hashed_path = str(context_path)
30 + hash_to_string_dict[hashed_path] = context_path
31 + current_result_line_parts += ['%s,%s,%s' % (context_word1, hashed_path, context_word2)]
32 + space_padding = ' ' * (self.config.MAX_CONTEXTS - len(contexts))
33 + result_line = ' '.join(current_result_line_parts) + space_padding
34 + result.append(result_line)
35 + return result, hash_to_string_dict
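
A minimal usage sketch; 'test.c2v' is assumed to already contain extracted path contexts in the format described earlier:

from config import Config
from py_extractor import PyExtractor

config = Config(set_defaults=True)          # provides MAX_CONTEXTS = 200
extractor = PyExtractor(config)
predict_lines, hash_to_path = extractor.extract_paths('test.c2v')
print(predict_lines[0][:60])                # 'method_name token,path,token ...' padded to MAX_CONTEXTS fields
print(len(hash_to_path))                    # number of distinct paths encountered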
1 +import tensorflow as tf
2 +import numpy as np
3 +import time
4 +from typing import Dict, Optional, List, Iterable
5 +from collections import Counter
6 +from functools import partial
7 +
8 +from path_context_reader import PathContextReader, ModelInputTensorsFormer, ReaderInputTensors, EstimatorAction
9 +from common import common
10 +from vocabularies import VocabType
11 +from config import Config
12 +from model_base import Code2VecModelBase, ModelEvaluationResults, ModelPredictionResults
13 +
14 +
15 +tf.compat.v1.disable_eager_execution()
16 +
17 +
18 +class Code2VecModel(Code2VecModelBase):
19 + def __init__(self, config: Config):
20 + self.sess = tf.compat.v1.Session()
21 + self.saver = None
22 +
23 + self.eval_reader = None
24 + self.eval_input_iterator_reset_op = None
25 + self.predict_reader = None
26 + self.MAX_BATCH_NUM = 30
27 +
28 + self.predict_placeholder = None
29 + self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, self.eval_code_vectors = None, None, None, None
30 + self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op = None, None, None
31 +
32 + self.vocab_type_to_tf_variable_name_mapping: Dict[VocabType, str] = {
33 + VocabType.Token: 'WORDS_VOCAB',
34 + VocabType.Target: 'TARGET_WORDS_VOCAB',
35 + VocabType.Path: 'PATHS_VOCAB'
36 + }
37 +
38 + super(Code2VecModel, self).__init__(config)
39 +
40 + def train(self):
41 + self.log('Starting training')
42 + start_time = time.time()
43 +
44 + batch_num = 0
45 + sum_loss = 0
46 + multi_batch_start_time = time.time()
47 + num_batches_to_save_and_eval = max(int(self.config.train_steps_per_epoch * self.config.SAVE_EVERY_EPOCHS), 1)
48 +
49 + train_reader = PathContextReader(vocabs=self.vocabs,
50 + model_input_tensors_former=_TFTrainModelInputTensorsFormer(),
51 + config=self.config, estimator_action=EstimatorAction.Train)
52 + input_iterator = tf.compat.v1.data.make_initializable_iterator(train_reader.get_dataset())
53 + input_iterator_reset_op = input_iterator.initializer
54 + input_tensors = input_iterator.get_next()
55 +
56 + optimizer, train_loss = self._build_tf_training_graph(input_tensors)
57 + self.saver = tf.compat.v1.train.Saver(max_to_keep=self.config.MAX_TO_KEEP)
58 +
59 + self.log('Number of trainable params: {}'.format(
60 + np.sum([np.prod(v.get_shape().as_list()) for v in tf.compat.v1.trainable_variables()])))
61 + for variable in tf.compat.v1.trainable_variables():
62 + self.log("variable name: {} -- shape: {} -- #params: {}".format(
63 + variable.name, variable.get_shape(), np.prod(variable.get_shape().as_list())))
64 +
65 + self._initialize_session_variables()
66 +
67 + if self.config.MODEL_LOAD_PATH:
68 + self._load_inner_model(self.sess)
69 +
70 + self.sess.run(input_iterator_reset_op)
71 + time.sleep(1)
72 + self.log('Started reader...')
73 +
74 + try:
75 + while batch_num <= self.MAX_BATCH_NUM:
76 + batch_num += 1
77 +
78 + _, batch_loss = self.sess.run([optimizer, train_loss])
79 +
80 + sum_loss += batch_loss
81 + if batch_num % self.config.NUM_BATCHES_TO_LOG_PROGRESS == 0:
82 + self._trace_training(sum_loss, batch_num, multi_batch_start_time)
83 + sum_loss = 0
84 + multi_batch_start_time = time.time()
85 + if batch_num % num_batches_to_save_and_eval == 0:
86 + epoch_num = int((batch_num / num_batches_to_save_and_eval) * self.config.SAVE_EVERY_EPOCHS)
87 + model_save_path = self.config.MODEL_SAVE_PATH + '_iter' + str(epoch_num)
88 + self.save(model_save_path)
89 + self.log('Saved after %d epochs in: %s' % (epoch_num, model_save_path))
90 + evaluation_results = self.evaluate()
91 + evaluation_results_str = (str(evaluation_results).replace('topk', 'top{}'.format(
92 + self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))
93 + self.log('After {nr_epochs} epochs -- {evaluation_results}'.format(
94 + nr_epochs=epoch_num,
95 + evaluation_results=evaluation_results_str
96 + ))
97 + except tf.errors.OutOfRangeError:
98 + self.log('Session Exhausted during the batch training')
99 + pass # exhausted
100 +
101 + self.log('Done training')
102 +
103 + if self.config.MODEL_SAVE_PATH:
104 + self._save_inner_model(self.config.MODEL_SAVE_PATH)
105 + self.log('Model saved in file: %s' % self.config.MODEL_SAVE_PATH)
106 +
107 + elapsed = int(time.time() - start_time)
108 + self.log("Training time: %sH:%sM:%sS\n" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))
109 +
110 + def evaluate(self) -> Optional[ModelEvaluationResults]:
111 + eval_start_time = time.time()
112 + if self.eval_reader is None:
113 + self.eval_reader = PathContextReader(vocabs=self.vocabs,
114 + model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(),
115 + config=self.config, estimator_action=EstimatorAction.Evaluate)
116 + input_iterator = tf.compat.v1.data.make_initializable_iterator(self.eval_reader.get_dataset())
117 + self.eval_input_iterator_reset_op = input_iterator.initializer
118 + input_tensors = input_iterator.get_next()
119 +
120 + self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, _, _, _, _, \
121 + self.eval_code_vectors = self._build_tf_test_graph(input_tensors)
122 + if self.saver is None:
123 + self.saver = tf.compat.v1.train.Saver()
124 +
125 + if self.config.MODEL_LOAD_PATH and not self.config.TRAIN_DATA_PATH_PREFIX:
126 + self._initialize_session_variables()
127 + self._load_inner_model(self.sess)
128 + if self.config.RELEASE:
129 + release_name = self.config.MODEL_LOAD_PATH + '.release'
130 + self.log('Releasing model, output model: %s' % release_name)
131 + self.saver.save(self.sess, release_name)
132 + return None
133 +
134 + with open('log.txt', 'w') as log_output_file:
135 + if self.config.EXPORT_CODE_VECTORS:
136 + code_vectors_file = open(self.config.TEST_DATA_PATH + '.vectors', 'w')
137 + total_predictions = 0
138 + total_prediction_batches = 0
139 + subtokens_evaluation_metric = SubtokensEvaluationMetric(
140 + partial(common.filter_impossible_names, self.vocabs.target_vocab.special_words))
141 + topk_accuracy_evaluation_metric = TopKAccuracyEvaluationMetric(
142 + self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION,
143 + partial(common.get_first_match_word_from_top_predictions, self.vocabs.target_vocab.special_words))
144 + start_time = time.time()
145 +
146 + self.sess.run(self.eval_input_iterator_reset_op)
147 +
148 + self.log('Starting evaluation')
149 +
150 + batch_num = 0
151 + try:
152 + while batch_num <= self.MAX_BATCH_NUM:
153 + batch_num += 1
154 +
155 + top_words, top_scores, original_names, code_vectors = self.sess.run(
156 + [self.eval_top_words_op, self.eval_top_values_op,
157 + self.eval_original_names_op, self.eval_code_vectors],
158 + )
159 +
160 + top_words = common.binary_to_string_matrix(top_words) # (batch, top_k)
161 + original_names = common.binary_to_string_list(original_names) # (batch,)
162 +
163 + self._log_predictions_during_evaluation(zip(original_names, top_words), log_output_file)
164 + topk_accuracy_evaluation_metric.update_batch(zip(original_names, top_words))
165 + subtokens_evaluation_metric.update_batch(zip(original_names, top_words))
166 +
167 + total_predictions += len(original_names)
168 + total_prediction_batches += 1
169 + if self.config.EXPORT_CODE_VECTORS:
170 + self._write_code_vectors(code_vectors_file, code_vectors)
171 + if total_prediction_batches % self.config.NUM_BATCHES_TO_LOG_PROGRESS == 0:
172 + elapsed = time.time() - start_time
173 + self._trace_evaluation(total_predictions, elapsed)
174 +
175 + except tf.errors.OutOfRangeError:
176 + self.log('Session Exhausted during the batch evaluating')
177 + pass
178 +
179 + self.log('Done evaluating, epoch reached')
180 + log_output_file.write(str(topk_accuracy_evaluation_metric.topk_correct_predictions) + '\n')
181 +
182 + if self.config.EXPORT_CODE_VECTORS:
183 + code_vectors_file.close()
184 +
185 + elapsed = int(time.time() - eval_start_time)
186 + self.log("Evaluation time: %sH:%sM:%sS" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))
187 + return ModelEvaluationResults(
188 + topk_acc=topk_accuracy_evaluation_metric.topk_correct_predictions,
189 + subtoken_precision=subtokens_evaluation_metric.precision,
190 + subtoken_recall=subtokens_evaluation_metric.recall,
191 + subtoken_f1=subtokens_evaluation_metric.f1)
192 +
193 + def _build_tf_training_graph(self, input_tensors):
194 + input_tensors = _TFTrainModelInputTensorsFormer().from_model_input_form(input_tensors)
195 +
196 + with tf.compat.v1.variable_scope('model'):
197 + tokens_vocab = tf.compat.v1.get_variable(
198 + self.vocab_type_to_tf_variable_name_mapping[VocabType.Token],
199 + shape=(self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE), dtype=tf.float32,
200 + initializer=tf.compat.v1.initializers.variance_scaling(scale=1.0, mode='fan_out', distribution="uniform"))
201 + targets_vocab = tf.compat.v1.get_variable(
202 + self.vocab_type_to_tf_variable_name_mapping[VocabType.Target],
203 + shape=(self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE), dtype=tf.float32,
204 + initializer=tf.compat.v1.initializers.variance_scaling(scale=1.0, mode='fan_out', distribution="uniform"))
205 + attention_param = tf.compat.v1.get_variable(
206 + 'ATTENTION',
207 + shape=(self.config.CODE_VECTOR_SIZE, 1), dtype=tf.float32)
208 + paths_vocab = tf.compat.v1.get_variable(
209 + self.vocab_type_to_tf_variable_name_mapping[VocabType.Path],
210 + shape=(self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE), dtype=tf.float32,
211 + initializer=tf.compat.v1.initializers.variance_scaling(scale=1.0, mode='fan_out', distribution="uniform"))
212 +
213 + code_vectors, _ = self._calculate_weighted_contexts(
214 + tokens_vocab, paths_vocab, attention_param, input_tensors.path_source_token_indices,
215 + input_tensors.path_indices, input_tensors.path_target_token_indices, input_tensors.context_valid_mask)
216 +
217 + logits = tf.matmul(code_vectors, targets_vocab, transpose_b=True)
218 + batch_size = tf.cast(tf.shape(input_tensors.target_index)[0], dtype=tf.float32)
219 + loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(
220 + labels=tf.reshape(input_tensors.target_index, [-1]),
221 + logits=logits)) / batch_size
222 +
223 + optimizer = tf.compat.v1.train.AdamOptimizer().minimize(loss)
224 +
225 + return optimizer, loss
226 +
227 + def _calculate_weighted_contexts(self, tokens_vocab, paths_vocab, attention_param, source_input, path_input,
228 + target_input, valid_mask, is_evaluating=False):
229 + source_word_embed = tf.nn.embedding_lookup(params=tokens_vocab, ids=source_input)
230 + path_embed = tf.nn.embedding_lookup(params=paths_vocab, ids=path_input)
231 + target_word_embed = tf.nn.embedding_lookup(params=tokens_vocab, ids=target_input)
232 +
233 + context_embed = tf.concat([source_word_embed, path_embed, target_word_embed],
234 + axis=-1)
235 +
236 + if not is_evaluating:
237 + context_embed = tf.nn.dropout(context_embed, rate=1-self.config.DROPOUT_KEEP_RATE)
238 +
239 + flat_embed = tf.reshape(context_embed, [-1, self.config.context_vector_size])
240 + transform_param = tf.compat.v1.get_variable(
241 + 'TRANSFORM', shape=(self.config.context_vector_size, self.config.CODE_VECTOR_SIZE), dtype=tf.float32)
242 +
243 + flat_embed = tf.tanh(tf.matmul(flat_embed, transform_param))
244 +
245 + contexts_weights = tf.matmul(flat_embed, attention_param)
246 + batched_contexts_weights = tf.reshape(
247 + contexts_weights, [-1, self.config.MAX_CONTEXTS, 1])
248 + mask = tf.math.log(valid_mask)
249 + mask = tf.expand_dims(mask, axis=2)
250 + batched_contexts_weights += mask
251 + attention_weights = tf.nn.softmax(batched_contexts_weights, axis=1)
252 +
253 + batched_embed = tf.reshape(flat_embed, shape=[-1, self.config.MAX_CONTEXTS, self.config.CODE_VECTOR_SIZE])
254 + code_vectors = tf.reduce_sum(tf.multiply(batched_embed, attention_weights), axis=1)
255 +
256 + return code_vectors, attention_weights
257 +
258 + def _build_tf_test_graph(self, input_tensors, normalize_scores=False):
259 + with tf.compat.v1.variable_scope('model', reuse=self.get_should_reuse_variables()):
260 + tokens_vocab = tf.compat.v1.get_variable(
261 + self.vocab_type_to_tf_variable_name_mapping[VocabType.Token],
262 + shape=(self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE),
263 + dtype=tf.float32, trainable=False)
264 + targets_vocab = tf.compat.v1.get_variable(
265 + self.vocab_type_to_tf_variable_name_mapping[VocabType.Target],
266 + shape=(self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE),
267 + dtype=tf.float32, trainable=False)
268 + attention_param = tf.compat.v1.get_variable(
269 + 'ATTENTION', shape=(self.config.context_vector_size, 1),
270 + dtype=tf.float32, trainable=False)
271 + paths_vocab = tf.compat.v1.get_variable(
272 + self.vocab_type_to_tf_variable_name_mapping[VocabType.Path],
273 + shape=(self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE),
274 + dtype=tf.float32, trainable=False)
275 +
276 + targets_vocab = tf.transpose(targets_vocab)
277 +
278 + input_tensors = _TFEvaluateModelInputTensorsFormer().from_model_input_form(input_tensors)
279 +
280 + code_vectors, attention_weights = self._calculate_weighted_contexts(
281 + tokens_vocab, paths_vocab, attention_param, input_tensors.path_source_token_indices,
282 + input_tensors.path_indices, input_tensors.path_target_token_indices,
283 + input_tensors.context_valid_mask, is_evaluating=True)
284 +
285 + scores = tf.matmul(code_vectors, targets_vocab) # (batch, target_word_vocab)
286 +
287 + topk_candidates = tf.nn.top_k(scores, k=tf.minimum(
288 + self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION, self.vocabs.target_vocab.size))
289 + top_indices = topk_candidates.indices
290 + top_words = self.vocabs.target_vocab.lookup_word(top_indices)
291 + original_words = input_tensors.target_string
292 + top_scores = topk_candidates.values
293 + if normalize_scores:
294 + top_scores = tf.nn.softmax(top_scores)
295 +
296 + return top_words, top_scores, original_words, attention_weights, input_tensors.path_source_token_strings, \
297 + input_tensors.path_strings, input_tensors.path_target_token_strings, code_vectors
298 +
299 + def predict(self, predict_data_lines: Iterable[str]) -> List[ModelPredictionResults]:
300 + if self.predict_reader is None:
301 + self.predict_reader = PathContextReader(vocabs=self.vocabs,
302 + model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(),
303 + config=self.config, estimator_action=EstimatorAction.Predict)
304 + self.predict_placeholder = tf.compat.v1.placeholder(tf.string)
305 + reader_output = self.predict_reader.process_input_row(self.predict_placeholder)
306 +
307 + self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op, \
308 + self.attention_weights_op, self.predict_source_string, self.predict_path_string, \
309 + self.predict_path_target_string, self.predict_code_vectors = \
310 + self._build_tf_test_graph(reader_output, normalize_scores=True)
311 +
312 + self._initialize_session_variables()
313 + self.saver = tf.compat.v1.train.Saver()
314 + self._load_inner_model(sess=self.sess)
315 +
316 + prediction_results: List[ModelPredictionResults] = []
317 + for line in predict_data_lines:
318 + batch_top_words, batch_top_scores, batch_original_name, batch_attention_weights, batch_path_source_strings,\
319 + batch_path_strings, batch_path_target_strings, batch_code_vectors = self.sess.run(
320 + [self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op,
321 + self.attention_weights_op, self.predict_source_string, self.predict_path_string,
322 + self.predict_path_target_string, self.predict_code_vectors],
323 + feed_dict={self.predict_placeholder: line})
324 +
325 + assert all(tensor.shape[0] == 1 for tensor in (batch_top_words, batch_top_scores, batch_original_name,
326 + batch_attention_weights, batch_path_source_strings,
327 + batch_path_strings, batch_path_target_strings,
328 + batch_code_vectors))
329 + top_words = np.squeeze(batch_top_words, axis=0)
330 + top_scores = np.squeeze(batch_top_scores, axis=0)
331 + original_name = batch_original_name[0]
332 + attention_weights = np.squeeze(batch_attention_weights, axis=0)
333 + path_source_strings = np.squeeze(batch_path_source_strings, axis=0)
334 + path_strings = np.squeeze(batch_path_strings, axis=0)
335 + path_target_strings = np.squeeze(batch_path_target_strings, axis=0)
336 + code_vectors = np.squeeze(batch_code_vectors, axis=0)
337 +
338 + top_words = common.binary_to_string_list(top_words)
339 + original_name = common.binary_to_string(original_name)
340 + attention_per_context = self._get_attention_weight_per_context(
341 + path_source_strings, path_strings, path_target_strings, attention_weights)
342 + prediction_results.append(ModelPredictionResults(
343 + original_name=original_name,
344 + topk_predicted_words=top_words,
345 + topk_predicted_words_scores=top_scores,
346 + attention_per_context=attention_per_context,
347 + code_vector=(code_vectors if self.config.EXPORT_CODE_VECTORS else None)
348 + ))
349 + return prediction_results
350 +
351 + def _save_inner_model(self, path: str):
352 + self.saver.save(self.sess, path)
353 +
354 + def _load_inner_model(self, sess=None):
355 + if sess is not None:
356 + self.log('Loading model weights from: ' + self.config.MODEL_LOAD_PATH)
357 + self.saver.restore(sess, self.config.MODEL_LOAD_PATH)
358 + self.log('Done loading model weights')
359 +
360 + def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
361 + assert vocab_type in VocabType
362 + vocab_tf_variable_name = self.vocab_type_to_tf_variable_name_mapping[vocab_type]
363 +
364 + if self.eval_reader is None:
365 + self.eval_reader = PathContextReader(vocabs=self.vocabs,
366 + model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(),
367 + config=self.config, estimator_action=EstimatorAction.Evaluate)
368 + input_iterator = tf.compat.v1.data.make_initializable_iterator(self.eval_reader.get_dataset())
369 + _, _, _, _, _, _, _, _ = self._build_tf_test_graph(input_iterator.get_next())
370 +
371 + if vocab_type is VocabType.Token:
372 + shape = (self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE)
373 + elif vocab_type is VocabType.Target:
374 + shape = (self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE)
375 + elif vocab_type is VocabType.Path:
376 + shape = (self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE)
377 +
378 + with tf.compat.v1.variable_scope('model', reuse=True):
379 + embeddings = tf.compat.v1.get_variable(vocab_tf_variable_name, shape=shape)
380 + self.saver = tf.compat.v1.train.Saver()
381 + self._initialize_session_variables()
382 + self._load_inner_model(self.sess)
383 + vocab_embedding_matrix = self.sess.run(embeddings)
384 + return vocab_embedding_matrix
385 +
386 + def get_should_reuse_variables(self):
387 + if self.config.TRAIN_DATA_PATH_PREFIX:
388 + return True
389 + else:
390 + return None
391 +
392 + def _log_predictions_during_evaluation(self, results, output_file):
393 + for original_name, top_predicted_words in results:
394 + found_match = common.get_first_match_word_from_top_predictions(
395 + self.vocabs.target_vocab.special_words, original_name, top_predicted_words)
396 + if found_match is not None:
397 + prediction_idx, predicted_word = found_match
398 + if prediction_idx == 0:
399 + output_file.write('Original: ' + original_name + ', predicted 1st: ' + predicted_word + '\n')
400 + else:
401 + output_file.write('\t\t predicted correctly at rank: ' + str(prediction_idx + 1) + '\n')
402 + else:
403 + output_file.write('No results for predicting: ' + original_name)
404 +
405 + def _trace_training(self, sum_loss, batch_num, multi_batch_start_time):
406 + multi_batch_elapsed = time.time() - multi_batch_start_time
407 + avg_loss = sum_loss / (self.config.NUM_BATCHES_TO_LOG_PROGRESS * self.config.TRAIN_BATCH_SIZE)
408 + throughput = self.config.TRAIN_BATCH_SIZE * self.config.NUM_BATCHES_TO_LOG_PROGRESS / \
409 + (multi_batch_elapsed if multi_batch_elapsed > 0 else 1)
410 + self.log('Average loss at batch %d: %f, \tthroughput: %d samples/sec' % (
411 + batch_num, avg_loss, throughput))
412 +
413 + def _trace_evaluation(self, total_predictions, elapsed):
414 + state_message = 'Evaluated %d examples...' % total_predictions
415 + throughput_message = "Prediction throughput: %d samples/sec" % int(
416 + total_predictions / (elapsed if elapsed > 0 else 1))
417 + self.log(state_message)
418 + self.log(throughput_message)
419 +
420 + def close_session(self):
421 + self.sess.close()
422 +
423 + def _initialize_session_variables(self):
424 + self.sess.run(tf.group(
425 + tf.compat.v1.global_variables_initializer(),
426 + tf.compat.v1.local_variables_initializer(),
427 + tf.compat.v1.tables_initializer()))
428 + self.log('Initialized variables')
429 +
430 +
431 +class SubtokensEvaluationMetric:
432 + def __init__(self, filter_impossible_names_fn):
433 + self.nr_true_positives: int = 0
434 + self.nr_false_positives: int = 0
435 + self.nr_false_negatives: int = 0
436 + self.nr_predictions: int = 0
437 + self.filter_impossible_names_fn = filter_impossible_names_fn
438 +
439 + def update_batch(self, results):
440 + for original_name, top_words in results:
441 + try:
442 + possible_names = self.filter_impossible_names_fn(top_words)
443 + prediction = possible_names[0]
444 + original_subtokens = Counter(common.get_subtokens(original_name))
445 + predicted_subtokens = Counter(common.get_subtokens(prediction))
446 + self.nr_true_positives += sum(count for element, count in predicted_subtokens.items()
447 + if element in original_subtokens)
448 + self.nr_false_positives += sum(count for element, count in predicted_subtokens.items()
449 + if element not in original_subtokens)
450 + self.nr_false_negatives += sum(count for element, count in original_subtokens.items()
451 + if element not in predicted_subtokens)
452 + self.nr_predictions += 1
453 + except Exception as e:
454 + print(e)
455 + print("List Length:", len(test))
456 + for p in test:
457 + print(p, end=' ')
458 + print('')
459 + print("Top Words:", top_words)
460 + raise
461 +
462 + @property
463 + def true_positive(self):
464 + return self.nr_true_positives / self.nr_predictions
465 +
466 + @property
467 + def false_positive(self):
468 + return self.nr_false_positives / self.nr_predictions
469 +
470 + @property
471 + def false_negative(self):
472 + return self.nr_false_negatives / self.nr_predictions
473 +
474 + @property
475 + def precision(self):
476 + return self.nr_true_positives / (self.nr_true_positives + self.nr_false_positives)
477 +
478 + @property
479 + def recall(self):
480 + return self.nr_true_positives / (self.nr_true_positives + self.nr_false_negatives)
481 +
482 + @property
483 + def f1(self):
484 + return 2 * self.precision * self.recall / (self.precision + self.recall)
485 +
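# Worked example for the subtoken metric above (illustrative only; the names are made up):
# if the original name has subtokens {get, file, name} and the best prediction has {get, name},
# then TP = 2, FP = 0, FN = 1, giving precision = 2/2 = 1.0, recall = 2/3,
# and F1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8.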
486 +
487 +class TopKAccuracyEvaluationMetric:
488 + def __init__(self, top_k: int, get_first_match_word_from_top_predictions_fn):
489 + self.top_k = top_k
490 + self.nr_correct_predictions = np.zeros(self.top_k)
491 + self.nr_predictions: int = 0
492 + self.get_first_match_word_from_top_predictions_fn = get_first_match_word_from_top_predictions_fn
493 +
494 + def update_batch(self, results):
495 + for original_name, top_predicted_words in results:
496 + self.nr_predictions += 1
497 + found_match = self.get_first_match_word_from_top_predictions_fn(original_name, top_predicted_words)
498 + if found_match is not None:
499 + suggestion_idx, _ = found_match
500 + self.nr_correct_predictions[suggestion_idx:self.top_k] += 1
501 +
502 + @property
503 + def topk_correct_predictions(self):
504 + return self.nr_correct_predictions / self.nr_predictions
505 +
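# Worked example for the top-k accuracy metric above (illustrative only):
# with top_k = 5, a correct word found at rank 2 increments nr_correct_predictions[1:5],
# so that prediction counts towards top-2 through top-5 accuracy but not towards top-1.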
506 +
507 +class _TFTrainModelInputTensorsFormer(ModelInputTensorsFormer):
508 + def to_model_input_form(self, input_tensors: ReaderInputTensors):
509 + return input_tensors.target_index, input_tensors.path_source_token_indices, input_tensors.path_indices, \
510 + input_tensors.path_target_token_indices, input_tensors.context_valid_mask
511 +
512 + def from_model_input_form(self, input_row) -> ReaderInputTensors:
513 + return ReaderInputTensors(
514 + target_index=input_row[0],
515 + path_source_token_indices=input_row[1],
516 + path_indices=input_row[2],
517 + path_target_token_indices=input_row[3],
518 + context_valid_mask=input_row[4]
519 + )
520 +
521 +
522 +class _TFEvaluateModelInputTensorsFormer(ModelInputTensorsFormer):
523 + def to_model_input_form(self, input_tensors: ReaderInputTensors):
524 + return (input_tensors.target_string, input_tensors.path_source_token_indices, input_tensors.path_indices,
525 + input_tensors.path_target_token_indices, input_tensors.context_valid_mask,
526 + input_tensors.path_source_token_strings, input_tensors.path_strings,
527 + input_tensors.path_target_token_strings)
528 +
529 + def from_model_input_form(self, input_row) -> ReaderInputTensors:
530 + return ReaderInputTensors(
531 + target_string=input_row[0],
532 + path_source_token_indices=input_row[1],
533 + path_indices=input_row[2],
534 + path_target_token_indices=input_row[3],
535 + context_valid_mask=input_row[4],
536 + path_source_token_strings=input_row[5],
537 + path_strings=input_row[6],
538 + path_target_token_strings=input_row[7]
539 + )
1 +type=python
2 +dataset_name=dataset
3 +data_dir=../data/${dataset_name}
4 +data=${data_dir}/${dataset_name}
5 +test_data=${data_dir}/${dataset_name}.val.c2v
6 +model_dir=models/${type}
7 +
8 +mkdir -p ${model_dir}
9 +set -e
10 +python -u code2vec.py --data ${data} --save ${model_dir}/saved_model --test ${test_data}
1 +from itertools import chain
2 +from typing import Optional, Dict, Iterable, Set, NamedTuple
3 +import pickle
4 +import os
5 +from enum import Enum
6 +from config import Config
7 +import tensorflow as tf
8 +from argparse import Namespace
9 +
10 +from common import common
11 +
12 +
13 +class VocabType(Enum):
14 + Token = 1
15 + Target = 2
16 + Path = 3
17 +
18 +
19 +SpecialVocabWordsType = Namespace
20 +
21 +
22 +_SpecialVocabWords_OnlyOov = Namespace(
23 + OOV='<OOV>'
24 +)
25 +
26 +_SpecialVocabWords_SeparateOovPad = Namespace(
27 + PAD='<PAD>',
28 + OOV='<OOV>'
29 +)
30 +
31 +_SpecialVocabWords_JoinedOovPad = Namespace(
32 + PAD_OR_OOV='<PAD_OR_OOV>',
33 + PAD='<PAD_OR_OOV>',
34 + OOV='<PAD_OR_OOV>'
35 +)
36 +
37 +
38 +class Vocab:
39 + def __init__(self, vocab_type: VocabType, words: Iterable[str],
40 + special_words: Optional[SpecialVocabWordsType] = None):
41 + if special_words is None:
42 + special_words = Namespace()
43 +
44 + self.vocab_type = vocab_type
45 + self.word_to_index: Dict[str, int] = {}
46 + self.index_to_word: Dict[int, str] = {}
47 + self._word_to_index_lookup_table = None
48 + self._index_to_word_lookup_table = None
49 + self.special_words: SpecialVocabWordsType = special_words
50 +
51 + for index, word in enumerate(chain(common.get_unique_list(special_words.__dict__.values()), words)):
52 + self.word_to_index[word] = index
53 + self.index_to_word[index] = word
54 +
55 + self.size = len(self.word_to_index)
56 +
57 + def save_to_file(self, file):
58 + special_words_as_unique_list = common.get_unique_list(self.special_words.__dict__.values())
59 + nr_special_words = len(special_words_as_unique_list)
60 + word_to_index_wo_specials = {word: idx for word, idx in self.word_to_index.items() if idx >= nr_special_words}
61 + index_to_word_wo_specials = {idx: word for idx, word in self.index_to_word.items() if idx >= nr_special_words}
62 + size_wo_specials = self.size - nr_special_words
63 + pickle.dump(word_to_index_wo_specials, file)
64 + pickle.dump(index_to_word_wo_specials, file)
65 + pickle.dump(size_wo_specials, file)
66 +
67 + @classmethod
68 + def load_from_file(cls, vocab_type: VocabType, file, special_words: SpecialVocabWordsType) -> 'Vocab':
69 + special_words_as_unique_list = common.get_unique_list(special_words.__dict__.values())
70 +
71 + word_to_index_wo_specials = pickle.load(file)
72 + index_to_word_wo_specials = pickle.load(file)
73 + size_wo_specials = pickle.load(file)
74 + assert len(index_to_word_wo_specials) == len(word_to_index_wo_specials) == size_wo_specials
75 + min_word_idx_wo_specials = min(index_to_word_wo_specials.keys())
76 +
77 + if min_word_idx_wo_specials != len(special_words_as_unique_list):
78 + raise ValueError(
79 + "Error while attempting to load vocabulary `{vocab_type}` from file `{file_path}`. "
80 + "The stored vocabulary has minimum word index {min_word_idx}, "
81 + "while expecting minimum word index to be {nr_special_words} "
82 + "because having to use {nr_special_words} special words, which are: {special_words}. "
83 + "Please check the parameter `config.SEPARATE_OOV_AND_PAD`.".format(
84 + vocab_type=vocab_type, file_path=file.name, min_word_idx=min_word_idx_wo_specials,
85 + nr_special_words=len(special_words_as_unique_list), special_words=special_words))
86 +
87 + vocab = cls(vocab_type, [], special_words)
88 + vocab.word_to_index = {**word_to_index_wo_specials,
89 + **{word: idx for idx, word in enumerate(special_words_as_unique_list)}}
90 + vocab.index_to_word = {**index_to_word_wo_specials,
91 + **{idx: word for idx, word in enumerate(special_words_as_unique_list)}}
92 + vocab.size = size_wo_specials + len(special_words_as_unique_list)
93 + return vocab
94 +
95 + @classmethod
96 + def create_from_freq_dict(cls, vocab_type: VocabType, word_to_count: Dict[str, int], max_size: int,
97 + special_words: Optional[SpecialVocabWordsType] = None):
98 + if special_words is None:
99 + special_words = Namespace()
100 + words_sorted_by_counts = sorted(word_to_count, key=word_to_count.get, reverse=True)
101 + words_sorted_by_counts_and_limited = words_sorted_by_counts[:max_size]
102 + return cls(vocab_type, words_sorted_by_counts_and_limited, special_words)
103 +
104 + @staticmethod
105 + def _create_word_to_index_lookup_table(word_to_index: Dict[str, int], default_value: int):
106 + return tf.lookup.StaticHashTable(
107 + tf.lookup.KeyValueTensorInitializer(
108 + list(word_to_index.keys()), list(word_to_index.values()), key_dtype=tf.string, value_dtype=tf.int32),
109 + default_value=tf.constant(default_value, dtype=tf.int32))
110 +
111 + @staticmethod
112 + def _create_index_to_word_lookup_table(index_to_word: Dict[int, str], default_value: str) \
113 + -> tf.lookup.StaticHashTable:
114 + return tf.lookup.StaticHashTable(
115 + tf.lookup.KeyValueTensorInitializer(
116 + list(index_to_word.keys()), list(index_to_word.values()), key_dtype=tf.int32, value_dtype=tf.string),
117 + default_value=tf.constant(default_value, dtype=tf.string))
118 +
119 + def get_word_to_index_lookup_table(self) -> tf.lookup.StaticHashTable:
120 + if self._word_to_index_lookup_table is None:
121 + self._word_to_index_lookup_table = self._create_word_to_index_lookup_table(
122 + self.word_to_index, default_value=self.word_to_index[self.special_words.OOV])
123 + return self._word_to_index_lookup_table
124 +
125 + def get_index_to_word_lookup_table(self) -> tf.lookup.StaticHashTable:
126 + if self._index_to_word_lookup_table is None:
127 + self._index_to_word_lookup_table = self._create_index_to_word_lookup_table(
128 + self.index_to_word, default_value=self.special_words.OOV)
129 + return self._index_to_word_lookup_table
130 +
131 + def lookup_index(self, word: tf.Tensor) -> tf.Tensor:
132 + return self.get_word_to_index_lookup_table().lookup(word)
133 +
134 + def lookup_word(self, index: tf.Tensor) -> tf.Tensor:
135 + return self.get_index_to_word_lookup_table().lookup(index)
136 +
137 +
138 +WordFreqDictType = Dict[str, int]
139 +
140 +
141 +class Code2VecWordFreqDicts(NamedTuple):
142 + token_to_count: WordFreqDictType
143 + path_to_count: WordFreqDictType
144 + target_to_count: WordFreqDictType
145 +
146 +
147 +class Code2VecVocabs:
148 + def __init__(self, config: Config):
149 + self.config = config
150 + self.token_vocab: Optional[Vocab] = None
151 + self.path_vocab: Optional[Vocab] = None
152 + self.target_vocab: Optional[Vocab] = None
153 +
154 + self._already_saved_in_paths: Set[str] = set()
155 +
156 + self._load_or_create()
157 +
158 + def _load_or_create(self):
159 + assert self.config.is_training or self.config.is_loading
160 + if self.config.is_loading:
161 + vocabularies_load_path = self.config.get_vocabularies_path_from_model_path(self.config.MODEL_LOAD_PATH)
162 + if not os.path.isfile(vocabularies_load_path):
163 + raise ValueError(
164 + "Model dictionaries file is not found in model load dir. "
165 + "Expecting file `{vocabularies_load_path}`.".format(vocabularies_load_path=vocabularies_load_path))
166 + self._load_from_path(vocabularies_load_path)
167 + else:
168 + self._create_from_word_freq_dict()
169 +
170 + def _load_from_path(self, vocabularies_load_path: str):
171 + assert os.path.exists(vocabularies_load_path)
172 + self.config.log('Loading model vocabularies from: `%s` ... ' % vocabularies_load_path)
173 + with open(vocabularies_load_path, 'rb') as file:
174 + self.token_vocab = Vocab.load_from_file(
175 + VocabType.Token, file, self._get_special_words_by_vocab_type(VocabType.Token))
176 + self.target_vocab = Vocab.load_from_file(
177 + VocabType.Target, file, self._get_special_words_by_vocab_type(VocabType.Target))
178 + self.path_vocab = Vocab.load_from_file(
179 + VocabType.Path, file, self._get_special_words_by_vocab_type(VocabType.Path))
180 + self.config.log('Done loading model vocabularies.')
181 + self._already_saved_in_paths.add(vocabularies_load_path)
182 +
183 + def _create_from_word_freq_dict(self):
184 + word_freq_dict = self._load_word_freq_dict()
185 + self.config.log('Word frequencies dictionaries loaded. Now creating vocabularies.')
186 + self.token_vocab = Vocab.create_from_freq_dict(
187 + VocabType.Token, word_freq_dict.token_to_count, self.config.MAX_TOKEN_VOCAB_SIZE,
188 + special_words=self._get_special_words_by_vocab_type(VocabType.Token))
189 + self.config.log('Created token vocab. size: %d' % self.token_vocab.size)
190 + self.path_vocab = Vocab.create_from_freq_dict(
191 + VocabType.Path, word_freq_dict.path_to_count, self.config.MAX_PATH_VOCAB_SIZE,
192 + special_words=self._get_special_words_by_vocab_type(VocabType.Path))
193 + self.config.log('Created path vocab. size: %d' % self.path_vocab.size)
194 + self.target_vocab = Vocab.create_from_freq_dict(
195 + VocabType.Target, word_freq_dict.target_to_count, self.config.MAX_TARGET_VOCAB_SIZE,
196 + special_words=self._get_special_words_by_vocab_type(VocabType.Target))
197 + self.config.log('Created target vocab. size: %d' % self.target_vocab.size)
198 +
199 + def _get_special_words_by_vocab_type(self, vocab_type: VocabType) -> SpecialVocabWordsType:
200 + if not self.config.SEPARATE_OOV_AND_PAD:
201 + return _SpecialVocabWords_JoinedOovPad
202 + if vocab_type == VocabType.Target:
203 + return _SpecialVocabWords_OnlyOov
204 + return _SpecialVocabWords_SeparateOovPad
205 +
206 + def save(self, vocabularies_save_path: str):
207 + if vocabularies_save_path in self._already_saved_in_paths:
208 + return
209 + with open(vocabularies_save_path, 'wb') as file:
210 + self.token_vocab.save_to_file(file)
211 + self.target_vocab.save_to_file(file)
212 + self.path_vocab.save_to_file(file)
213 + self._already_saved_in_paths.add(vocabularies_save_path)
214 +
215 + def _load_word_freq_dict(self) -> Code2VecWordFreqDicts:
216 + assert self.config.is_training
217 + self.config.log('Loading word frequencies dictionaries from: %s ... ' % self.config.word_freq_dict_path)
218 + with open(self.config.word_freq_dict_path, 'rb') as file:
219 + token_to_count = pickle.load(file)
220 + path_to_count = pickle.load(file)
221 + target_to_count = pickle.load(file)
222 + self.config.log('Done loading word frequencies dictionaries.')
223 +
224 + return Code2VecWordFreqDicts(
225 + token_to_count=token_to_count, path_to_count=path_to_count, target_to_count=target_to_count)
226 +
227 + def get(self, vocab_type: VocabType) -> Vocab:
228 + if not isinstance(vocab_type, VocabType):
229 + raise ValueError('`vocab_type` should be `VocabType.Token`, `VocabType.Target` or `VocabType.Path`.')
230 + if vocab_type == VocabType.Token:
231 + return self.token_vocab
232 + if vocab_type == VocabType.Target:
233 + return self.target_vocab
234 + if vocab_type == VocabType.Path:
235 + return self.path_vocab
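'''
Usage sketch (illustrative only; the words and counts below are made up):

    freq = {'get': 10, 'name': 7, 'file': 3}
    vocab = Vocab.create_from_freq_dict(VocabType.Token, freq, max_size=2,
                                        special_words=_SpecialVocabWords_JoinedOovPad)
    # 'file' falls outside max_size, so it resolves to the OOV index
    oov_index = vocab.lookup_index(tf.constant('file'))
'''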
1 +from github import Github
2 +import time
3 +import calendar
4 +
5 +DATASET_MAX = 1000
6 +
7 +class GithubCrawler:
8 + def __init__(self, token):
9 + self._token = token
10 + self._g = Github(token)
11 +
12 + def getTimeLimit(self):
13 + core_rate_limit = self._g.get_rate_limit().core
14 + reset_timestamp = calendar.timegm(core_rate_limit.reset.timetuple())
15 + sleep_time = reset_timestamp - calendar.timegm(time.gmtime()) + 1
16 + return sleep_time
17 +
18 + def search_repo(self, keywords, S = 0, E = DATASET_MAX):
19 + if type(keywords) == str:
20 + keywords = [keywords] #auto packing for one keyword
21 +
22 + query = '+'.join(keywords) + '+in:readme+in:description'
23 + result = self._g.search_repositories(query)
24 +
25 + ret = []
26 + for i in range(S, E):
27 + while True:
28 + try:
29 + r = result[i]
30 + repoName = r.owner.login+'/'+r.name
31 + print("repo found", f"[{i}]:", repoName)
32 + ret.append(repoName)
33 + break
34 + except Exception:
35 + print("Rate Limit Exceeded... Retrying", f"{[i]}", "Limit Time:", self.getTimeLimit())
36 + time.sleep(1)
37 +
38 + return ret
39 +
40 + def search_files(self, repo_url, downloadLink = False):
41 + while True:
42 + try:
43 + repo = self._g.get_repo(repo_url)
44 + break
45 + except Exception as e:
46 + if '403' in str(e):
47 + print("Rate Limit Exceeded... Retrying", f"{[i]}", "Limit Time:", self.getTimeLimit())
48 + time.sleep(1)
49 + continue
50 + print(e)
51 + return []
52 +
53 + try:
54 + contents = repo.get_contents("")
55 + except Exception: #empty repo
56 + return []
57 +
58 + files = []
59 +
60 + while contents:
61 + file_content = contents.pop(0)
62 + if file_content.type == 'dir':
63 + if 'lib' in file_content.path: # a bundled python lib directory means too many files -> skip this repo
64 + return []
65 + contents.extend(repo.get_contents(file_content.path))
66 + else:
67 + if downloadLink:
68 + files.append(file_content.download_url)
69 + else:
70 + files.append(file_content.path)
71 +
72 + return files
...\ No newline at end of file ...\ No newline at end of file
1 +import crawler
2 +import os
3 +import utils
4 +
5 +TOKEN = 'YOUR_TOKEN_HERE'
6 +DATASET_DIR = 'YOUR_PATH_HERE'
7 +REPO_PATH = 'repos.txt'
8 +
9 +utils.removeEmptyDirectories(DATASET_DIR)
10 +
11 +c = crawler.GithubCrawler(TOKEN)
12 +
13 +if not os.path.exists(REPO_PATH):
14 + repos = c.search_repo('MNIST+language:python', 1000, 2000)
15 + f = open(REPO_PATH, 'w')
16 + for r in repos:
17 + f.write(r + '\n')
18 + f.close()
19 +else:
20 + f = open(REPO_PATH, 'r')
21 + repos = f.readlines()
22 + f.close()
23 +
24 +S = 0
25 +L = len(repos)
26 +print("Found repositories:", L)
27 +
28 +for i in range(S, L):
29 + r = repos[i].strip()
30 + savename = r.replace('/', '_')
31 + print('Downloading', f'[{i}] :', savename)
32 +
33 + if os.path.exists(os.path.join(DATASET_DIR, savename)):
34 + continue
35 +
36 + files = c.search_files(r, True)
37 + files = list(filter(lambda x : utils.isformat(x, ['py', 'ipynb']), files))
38 + if len(files) > 0:
39 + utils.downloadFiles(DATASET_DIR, savename, files)
...\ No newline at end of file ...\ No newline at end of file
1 +import os
2 +from requests import get
3 +
4 +def isformat(file, typenames):
5 + if type(file) != str:
6 + return False
7 +
8 + if type(typenames) == str:
9 + typenames = [typenames]
10 +
11 + dot = file.rfind('.')
12 +
13 + if dot < 0:
14 + for t in typenames:
15 + if file == t:
16 + return True
17 + return False
18 +
19 + ext = file[dot + 1 :]
20 +
21 + for t in typenames:
22 + if ext == t:
23 + return True
24 +
25 + return False
26 +
27 +def downloadFiles(root, dir, urls):
28 + if not os.path.exists(root):
29 + os.mkdir(root)
30 +
31 + path = os.path.join(root, dir)
32 +
33 + if not os.path.exists(path):
34 + os.mkdir(path)
35 + else:
36 + return
37 +
38 + for url in urls:
39 + name = os.path.basename(url)
40 + # the with-block closes the file automatically, so no explicit close() is needed
41 + with open(os.path.join(path, name), 'wb') as f:
42 + try:
43 + response = get(url)
44 + f.write(response.content)
45 +
46 + except Exception as e:
47 + print(e)
48 + # stop downloading this repository after the first failed request
49 + break
50 +
51 +
52 +def removeEmptyDirectories(root):
53 + cnt = 0
54 + for dir in os.listdir(root):
55 + d = os.path.join(root, dir)
56 + if len(os.listdir(d)) == 0: #empty
57 + os.rmdir(d)
58 + cnt += 1
59 +
60 + print(cnt, "empty directories removed")
...\ No newline at end of file ...\ No newline at end of file
1 +class Block:
2 + def __init__(self, type, line=''):
3 + self.blocks = list()
4 + self.code = line
5 + self.blockType = type
6 + self.indent = -1
7 +
8 + def setIndent(self, indent):
9 + self.indent = indent
10 +
11 + def addLine(self, line):
12 + if len(self.code) > 0:
13 + self.code += '\n'
14 + self.code += line
15 +
16 + def addBlock(self, block):
17 + self.blocks.append(block)
18 +
19 + def debug(self):
20 + if self.blockType != 'TYPE_NORMAL':
21 + print("Block Info:", self.blockType, self.indent)
22 + print(self.code)
23 +
24 + for block in self.blocks:
25 + if block.indent <= self.indent:
26 + raise ValueError("Invalid Indent Error Occurred: {}, INDENT {} included in {}, INDENT {}".format(block.code, block.indent, self.code, self.indent))
27 + block.debug()
28 +
29 + def __str__(self):
30 + if len(self.code) > 0:
31 + result = self.code + '\n'
32 + else:
33 + result = ''
34 +
35 + for block in self.blocks:
36 + result += block.__str__()
37 +
38 + return result
...\ No newline at end of file ...\ No newline at end of file
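'''
Usage sketch for the Block tree (illustrative only; the code lines are made up):

    root = Block('TYPE_ROOT')
    fn = Block('TYPE_DEF', 'def add(a, b):')
    fn.setIndent(0)
    body = Block('TYPE_NORMAL', '    return a + b')
    body.setIndent(4)
    fn.addBlock(body)
    root.addBlock(fn)
    print(root)  # prints the two lines back as source text
'''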
1 +from utils import *
2 +import file_parser
3 +import random
4 +
5 +def merge_two_files(input, output): # pick two random files from input, merge and shuffle codes, print to output
6 + ori_files = [f for f in readdir(input) if is_extension(f, 'py')]
7 + files = ori_files.copy()
8 + random.shuffle(files)
9 +
10 + os.makedirs(output, exist_ok=True) # create the output directory if not exists
11 + log = open(os.path.join(output, 'log.txt'), 'w', encoding='utf8')
12 +
13 + index = 1
14 + while len(files) > 0:
15 + if len(files) == 1:
16 + one = random.choice(ori_files)
17 + while one == files[0]: # why python doesn't have do while loop??
18 + one = random.choice(ori_files)
19 +
20 + pick = [files[0], one]
21 + else:
22 + pick = files[:2]
23 +
24 + files = files[2:]
25 +
26 + lines1 = read_file(pick[0])
27 + lines2 = read_file(pick[1])
28 +
29 + print("Merging:", pick[0], pick[1])
30 +
31 + block1 = file_parser.parse_block(lines1)
32 + block2 = file_parser.parse_block(lines2)
33 +
34 + for b in block2.blocks:
35 + block1.addBlock(b)
36 +
37 + shuffle_block(block1)
38 + write_block(os.path.join(output, '{}.py'.format(index)), block1)
39 +
40 + log.write('{}.py {} {}\n'.format(index, pick[0], pick[1]))
41 + index += 1
42 +
43 + log.close()
44 + print("Done generating Merged Dataset")
45 + print("log.txt generated in output path, for merged file info. [merge_file_name file1 file2]")
46 +
47 +
48 +'''
49 + Usage: merge_two_files('data/original', 'data/merged')
50 +'''
...\ No newline at end of file ...\ No newline at end of file
1 +from utils import *
2 +import file_parser
3 +import re
4 +
5 +# obfuscator v1 reuses existing identifier names (maps each identifier to another name from the same set)
6 +
7 +def detect_vars(line): # detect variables and return range tuples. except for keywords
8 + ret = list()
9 + s = 0
10 + e = 0
11 + detected = False
12 + strException = False
13 + strCh = None
14 + line += ' ' # for last separator
15 +
16 + for i in range(len(line)):
17 + c = line[i]
18 +
19 + if not strException and (c == "'" or c == '"'): # we cannot remove string first, because index gets changed
20 + strCh = c
21 + strException = True
22 + continue
23 +
24 + if strException:
25 + if c == strCh:
26 + strException = False
27 + continue
28 +
29 + if not detected and re.match('[A-Za-z_]', c):
30 + detected = True
31 + s = i
32 + continue
33 +
34 + if detected and not re.match('[A-Za-z_0-9]', c):
35 + detected = False
36 + e = i
37 + ret.append((s, e))
38 +
39 + return ret
40 +
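# Quick illustration of detect_vars (illustrative only): for the line "x = foo(y)"
# it returns [(0, 1), (4, 7), (8, 9)] -- the character spans of x, foo and y.
# Spans whose text is not in `vars` are filtered out by obfuscate() below.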
41 +def obfuscate(lines, vars, dictionary, mapper): # obfuscate the code
42 + ret = list()
43 + ### write_file('D:/Develop/ori.py', lines)
44 +
45 + for line in lines:
46 + var_ranges = detect_vars(line)
47 + var_ranges = [(s, e) for (s, e) in var_ranges if line[s:e] in vars] # remove keywords (do not convert to words because of string exception)
48 + var_ranges.append((-1, -1)) # for out-of-range exception
49 +
50 + var_index = 0
51 + new_line = ''
52 + i = 0
53 + L = len(line)
54 +
55 + while i < L:
56 + if i == var_ranges[var_index][0]: # found var
57 + s, e = var_ranges[var_index]
58 + new_line += vars[mapper[dictionary[line[s:e]]]]
59 + i = e
60 + var_index += 1
61 + else:
62 + new_line += line[i]
63 + i += 1
64 +
65 + ret.append(new_line)
66 +
67 + ### write_file('D:/Develop/obf.py', ret)
68 + return ret
69 +
70 +def create_var_histogram(input, outPath):
71 + files = [f for f in readdir(input) if is_extension(f, 'py')]
72 + freq_dict = dict()
73 +
74 + for p in files:
75 + lines = read_file(p)
76 + lines = remove_unnecessary_comments(lines)
77 +
78 + for line in lines:
79 + file_parser.parse_keywords(line, freq_dict)
80 +
81 + hist = open(outPath, 'w', encoding='utf8')
82 + arr = sorted(freq_dict.items(), key=select_value)
83 + for i in arr:
84 + hist.write(str(i) + '\n')
85 + hist.close()
86 +
87 +def read_histogram(inputPath):
88 + lines = read_file(inputPath)
89 + ret = []
90 +
91 + for line in lines:
92 + line = line.split("'")[1]
93 + ret.append(line)
94 + return ret
95 +
96 +def obfuscate_files(input, output, var=None, threshold=4000): # obfuscate variables. Guessing variable names from keyword frequency (threshold) if variable list is not given
97 + files = [f for f in readdir(input) if is_extension(f, 'py')]
98 + freq_dict = dict()
99 + codes = list()
100 +
101 + for p in files:
102 + lines = read_file(p)
103 +
104 + lines = remove_unnecessary_comments(lines) # IMPORTANT: remove comments from lines for preprocessing
105 + codes.append((p, lines))
106 +
107 + if var == None:
108 + for line in lines:
109 + file_parser.parse_keywords(line, freq_dict)
110 +
111 +
112 + if var == None: # don't have variable list
113 + hist = open(os.path.join(output, 'log.txt'), 'w', encoding='utf8')
114 + arr = sorted(freq_dict.items(), key=select_value)
115 + for i in arr:
116 + hist.write(str(i) + '\n')
117 + hist.close()
118 +
119 + var, _ = threshold_dict(freq_dict, threshold)
120 + var = [v[0] for v in var]
121 +
122 + dictionary = create_dictionary(var)
123 + mapper = create_mapper(len(var))
124 +
125 + ### obfuscate(codes[0][1], var, dictionary, mapper)
126 +
127 + for path, code in codes:
128 + obfuscated = obfuscate(code, var, dictionary, mapper)
129 +
130 + filepath = path.split(input)[1][1:]
131 + os.makedirs(os.path.join(output, filepath.split('\\')[0]), exist_ok=True) # create the output directory if not exists
132 + new_path = os.path.join(output, filepath)
133 + write_file(new_path, obfuscated)
134 +
135 + print("Done generating Obfuscated Dataset")
136 +
137 +
138 +'''
139 +Usage
140 +obfuscate_files('data/original', 'data/obfuscated')
141 +'''
...\ No newline at end of file ...\ No newline at end of file
1 +from utils import *
2 +import file_parser
3 +import re
4 +
5 +# obfuscator v2 generates a random name for each identifier
6 +
7 +def random_character(start=False):
8 + if start:
9 + x = random.randint(0, 52)
10 + if x == 0:
11 + return '_'
12 + elif x <= 26:
13 + return chr(65 + x - 1)
14 + else:
15 + return chr(97 + x - 27)
16 +
17 + x = random.randint(0, 62)
18 + if x == 0:
19 + return '_'
20 + elif x <= 26:
21 + return chr(65 + x - 1)
22 + elif x <= 52:
23 + return chr(97 + x - 27)
24 + else:
25 + return str(x - 53)
26 +
27 +
28 +def create_mapper_v2(L):
29 + ret = []
30 + while len(ret) < L:
31 + length = random.randint(0, 8) + 4
32 + s = random_character(True)
33 +
34 + while len(s) < length:
35 + s += random_character()
36 +
37 + if not s in ret:
38 + ret.append(s)
39 +
40 + return ret
41 +
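# Example output of create_mapper_v2 (random, so results differ every run; the values below are made up):
# create_mapper_v2(3) -> ['aX93kfQw', '_bR2t', 'Zq8w1_d']
# i.e. unique identifier-like strings of length 4..12 that start with a letter or underscore.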
42 +def detect_vars(line): # detect variables and return range tuples. except for keywords
43 + ret = list()
44 + s = 0
45 + e = 0
46 + detected = False
47 + strException = False
48 + strCh = None
49 + line += ' ' # for last separator
50 +
51 + for i in range(len(line)):
52 + c = line[i]
53 +
54 + if not strException and (c == "'" or c == '"'): # we cannot remove string first, because index gets changed
55 + strCh = c
56 + strException = True
57 + continue
58 +
59 + if strException:
60 + if c == strCh:
61 + strException = False
62 + continue
63 +
64 + if not detected and re.match('[A-Za-z_]', c):
65 + detected = True
66 + s = i
67 + continue
68 +
69 + if detected and not re.match('[A-Za-z_0-9]', c):
70 + detected = False
71 + e = i
72 + ret.append((s, e))
73 +
74 + return ret
75 +
76 +def obfuscate(lines, vars, dictionary, mapper): # obfuscate the code
77 + ret = list()
78 + ### write_file('D:/Develop/ori.py', lines)
79 +
80 + for line in lines:
81 + var_ranges = detect_vars(line)
82 + var_ranges = [(s, e) for (s, e) in var_ranges if line[s:e] in vars] # remove keywords (do not convert to words because of string exception)
83 + var_ranges.append((-1, -1)) # for out-of-range exception
84 +
85 + var_index = 0
86 + new_line = ''
87 + i = 0
88 + L = len(line)
89 +
90 + while i < L:
91 + if i == var_ranges[var_index][0]: # found var
92 + s, e = var_ranges[var_index]
93 + new_line += mapper[dictionary[line[s:e]]]
94 + i = e
95 + var_index += 1
96 + else:
97 + new_line += line[i]
98 + i += 1
99 +
100 + ret.append(new_line)
101 +
102 + ### write_file('D:/Develop/obf.py', ret)
103 + return ret
104 +
105 +def create_var_histogram(input, outPath):
106 + files = [f for f in readdir(input) if is_extension(f, 'py')]
107 + freq_dict = dict()
108 +
109 + for p in files:
110 + lines = read_file(p)
111 + lines = remove_unnecessary_comments(lines)
112 +
113 + for line in lines:
114 + file_parser.parse_keywords(line, freq_dict)
115 +
116 + hist = open(outPath, 'w', encoding='utf8')
117 + arr = sorted(freq_dict.items(), key=select_value)
118 + for i in arr:
119 + hist.write(str(i) + '\n')
120 + hist.close()
121 +
122 +def read_histogram(inputPath):
123 + lines = read_file(inputPath)
124 + ret = []
125 +
126 + for line in lines:
127 + line = line.split("'")[1]
128 + ret.append(line)
129 + return ret
130 +
131 +def obfuscate_files(input, output, var=None, threshold=4000): # obfuscate variables. Guessing variable names from keyword frequency (threshold) if variable list is not given
132 + files = [f for f in readdir(input) if is_extension(f, 'py')]
133 + freq_dict = dict()
134 + codes = list()
135 +
136 + for p in files:
137 + lines = read_file(p)
138 +
139 + lines = remove_unnecessary_comments(lines) # IMPORTANT: remove comments from lines for preprocessing
140 + codes.append((p, lines))
141 +
142 + if var == None:
143 + for line in lines:
144 + file_parser.parse_keywords(line, freq_dict)
145 +
146 +
147 + if var == None: # don't have variable list
148 + hist = open(os.path.join(output, 'log.txt'), 'w', encoding='utf8')
149 + arr = sorted(freq_dict.items(), key=select_value)
150 + for i in arr:
151 + hist.write(str(i) + '\n')
152 + hist.close()
153 +
154 + var, _ = threshold_dict(freq_dict, threshold)
155 + var = [v[0] for v in var]
156 +
157 + dictionary = create_dictionary(var)
158 + mapper = create_mapper_v2(len(var))
159 +
160 + ### obfuscate(codes[0][1], var, dictionary, mapper)
161 +
162 + for path, code in codes:
163 + obfuscated = obfuscate(code, var, dictionary, mapper)
164 +
165 + filepath = path.split(input)[1][1:]
166 + os.makedirs(os.path.join(output, filepath.split('\\')[0]), exist_ok=True) # create the output directory if not exists
167 + new_path = os.path.join(output, filepath)
168 + write_file(new_path, obfuscated)
169 +
170 + print("Done generating Obfuscated Dataset")
171 +
172 +
173 +'''
174 +Usage
175 +obfuscate_files('data/original', 'data/obfuscated')
176 +'''
...\ No newline at end of file ...\ No newline at end of file
1 +from utils import *
2 +import file_parser
3 +import random
4 +
5 +def refine_files(input, output):
6 + files = [f for f in readdir(input) if is_extension(f, 'py')]
7 + random.shuffle(files)
8 +
9 + for p in files:
10 + lines = read_file(p)
11 +
12 + print("Refining:", p)
13 + block = file_parser.parse_block(lines)
14 +
15 + filepath = p.split(input)[1][1:]
16 + os.makedirs(os.path.join(output, filepath.split('\\')[0]), exist_ok=True) # create the output directory if not exists
17 + path = os.path.join(output, filepath)
18 + write_block(path, block)
19 +
20 + print("Done generating Refined Dataset")
...\ No newline at end of file ...\ No newline at end of file
1 +from utils import *
2 +import file_parser
3 +import random
4 +
5 +def shuffle_files(input, output): # pick random file and shuffle code order to output
6 + files = [f for f in readdir(input) if is_extension(f, 'py')]
7 + random.shuffle(files)
8 +
9 + for p in files:
10 + lines = read_file(p)
11 +
12 + print("Shuffling:", p)
13 + block = file_parser.parse_block(lines)
14 + shuffle_block(block)
15 +
16 + filepath = p.split(input)[1][1:]
17 + os.makedirs(os.path.join(output, filepath.split('\\')[0]), exist_ok=True) # create the output directory if not exists
18 + path = os.path.join(output, filepath)
19 + write_block(path, block)
20 +
21 + print("Done generating Shuffled Dataset")
22 +
23 +
24 +'''
25 +shuffle_files('data/original', 'data/shuffled')
26 +'''
...\ No newline at end of file ...\ No newline at end of file
1 +from utils import *
2 +import re
3 +import keyword
4 +
5 +'''
6 + Test multi-line comments
7 +'''
8 +
9 +LIBRARYS = list()
10 +
11 +def parse_keywords(line, out): # out : output dictionary to sum up frequencies
12 + line = line.strip()
13 + line = remove_string(line)
14 + result = ''
15 +
16 + for c in line:
17 + if re.match('[A-Za-z_@0-9]', c):
18 + result += c
19 + else:
20 + result += ' '
21 +
22 + import_line = False
23 + prev_key = ''
24 +
25 + for key in result.split(' '):
26 + if not key or is_number(key) or key[0] in "0123456789":
27 + continue
28 +
29 + ## Exception code here
30 +
31 + if key in ['from', 'import']:
32 + import_line = True
33 +
34 + if import_line and prev_key != 'as':
35 + if not key in LIBRARYS:
36 + LIBRARYS.append(key)
37 + prev_key = key
38 + continue
39 +
40 + if key in keyword.kwlist or key in LIBRARYS or '@' in key:
41 + prev_key = key
42 + continue
43 +
44 + prev_key = key
45 +
46 + ##
47 +
48 + if not key in out:
49 + out[key] = 1
50 + else:
51 + out[key] += 1
52 +
53 +def parse_block(lines): # parse to import / def / class / normal (if, for, etc)
54 + lines = remove_unnecessary_comments(lines)
55 + root = Block('TYPE_ROOT') # main block tree node
56 + block_stack = [root]
57 + i = 0
58 + L = len(lines)
59 + # par_stack = list()
60 + # multi_string_stack = list()
61 +
62 + while i < L:
63 + line = lines[i]
64 + start_index = 0
65 + indent_count = 0
66 +
67 + while True: # count indents
68 + if line[start_index] == '\t':
69 + start_index += 1
70 + indent_count += 4
71 + elif line[start_index] == ' ':
72 + start_index += 1
73 + indent_count += 1
74 + else:
75 + break
76 +
77 + block = create_block_from_line(line)
78 + block.setIndent(indent_count)
79 +
80 + if block.blockType == 'TYPE_FACTORY': # for the @factory/@property decorator exception
81 + i += 1
82 +
83 + temp = create_block_from_line(lines[i])
84 + if temp.blockType == 'TYPE_CLASS':
85 + block.addLine(lines[i])
86 + block.blockType = 'TYPE_CLASS'
87 + elif temp.blockType == 'TYPE_DEF':
88 + block.addLine(lines[i])
89 + block.blockType = 'TYPE_DEF'
90 + else: # unknown type exception (factory single lines, or multi line code)
91 + i -= 1 # roll back
92 +
93 + '''
94 + ### code for multi-line string/code detection, but too many exception. (most code works well due to indent parsing)
95 + line = lines[i]
96 + if detect_parenthesis(line, par_stack) or detect_multi_string(line, multi_string_stack) or detect_multi_line_code(lines[i]): # code is not ended in a single line
97 + i += 1
98 + while detect_parenthesis(lines[i], par_stack) or detect_multi_string(lines[i], multi_string_stack) or detect_multi_line_code(lines[i]):
99 + block.addLine(lines[i])
100 + i += 1
101 +
102 + block.addLine(lines[i])
103 + '''
104 +
105 + if indent_count == block_stack[-1].indent: # same indent -> change the block
106 + block_stack.pop()
107 + block_stack[-1].addBlock(block)
108 + block_stack.append(block)
109 + elif indent_count > block_stack[-1].indent: # block included in previous block
110 + block_stack[-1].addBlock(block)
111 + block_stack.append(block)
112 + else: # block ended
113 + while indent_count <= block_stack[-1].indent:
114 + block_stack.pop()
115 + block_stack[-1].addBlock(block)
116 + block_stack.append(block)
117 + i += 1
118 +
119 + return root
120 +
121 +
122 +"""
123 + Usage
124 +
125 + path = 'data/test.py'
126 + f = open(path, 'r')
127 + lines = f.readlines()
128 + f.close()
129 +
130 +
131 + block = parse_block(lines)
132 + block.debug()
133 +
134 +
135 + '''
136 + keywords = dict()
137 + parse_keywords(lines, keywords)
138 +
139 + for k, v in keywords.items():
140 + print(k,':',v)
141 +
142 + a, b = threshold_dict(keywords, 3)
143 +
144 + print(a)
145 + print(b)
146 + '''
147 +"""
148 +
149 +'''
150 +d = dict()
151 +parse_keywords('from test.library import a as x, b as y', d)
152 +print(d)
153 +'''
...\ No newline at end of file ...\ No newline at end of file
1 +from utils import remove_string
2 +import utils
3 +import data_merger
4 +import data_refiner
5 +import data_shuffler
6 +import file_parser
7 +import data_obfuscator_v2
8 +
9 +if __name__ == '__main__':
10 + input_path = 'data/original'
11 + data_refiner.refine_files(input_path, 'data/refined')
12 + data_merger.merge_two_files(input_path, 'data/merged')
13 + data_shuffler.shuffle_files(input_path, 'data/shuffled')
14 + vars = data_obfuscator_v2.read_histogram('data/histogram_v1.txt')
15 + data_obfuscator_v2.obfuscate_files(input_path, 'data/obfuscated2', vars)
16 +
17 + # utils.write_file('data/keyword_examples.txt', utils.search_keyword(input_path, 'rand'))
18 + # data_obfuscator.create_var_histogram(input_path, 'data/histogram.txt')
1 +from block import Block
2 +import bisect
3 +import os
4 +import re
5 +import random
6 +
7 +TYPE_CLASS = ['class']
8 +TYPE_DEF = ['def']
9 +TYPE_IMPORT = ['from', 'import']
10 +TYPE_CONDITION = ['if', 'elif', 'else', 'for', 'while', 'with']
11 +multi_line_comments = ["'''", '"""']
12 +
13 +def select_value(x):
14 + return x[1]
15 +
16 +def threshold_dict(d, val): # split dict in two by thresholding on value
17 + arr = sorted(d.items(), key=select_value)
18 + index = bisect.bisect_left([r[1] for r in arr], val)
19 + return arr[:index], arr[index:]
20 +
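# Worked example (illustrative only):
# threshold_dict({'a': 1, 'b': 5, 'c': 2}, 2) -> ([('a', 1)], [('c', 2), ('b', 5)])
# i.e. items with count < 2 on the left, items with count >= 2 on the right, both sorted by count.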
21 +def is_number(s):
22 + if s[0] == '-':
23 + s = s[1:]
24 + return s.replace('.','',1).isdigit()
25 +
26 +def is_extension(f, ext):
27 + return os.path.splitext(f)[1][1:] == ext
28 +
29 +def _readdir_r(dirpath): # recursive helper for readdir
30 + ret = []
31 + for f in os.listdir(dirpath):
32 + ret.append(os.path.join(dirpath, f))
33 +
34 + return ret
35 +
36 +def readdir(path): # read files from the directory
37 + pathList = [path]
38 + result = []
39 + i = 0
40 +
41 + while i < len(pathList):
42 + f = pathList[i]
43 + if os.path.isdir(f):
44 + pathList += _readdir_r(f)
45 + else:
46 + result.append(f)
47 +
48 + i += 1
49 +
50 + return result
51 +
52 +def remove_string(line):
53 + strIn = False
54 + strCh = None
55 + result = ''
56 + i = 0
57 + L = len(line)
58 +
59 + while i < L:
60 + if i + 3 < L:
61 + if line[i:i+3] in multi_line_comments:
62 + if not strIn:
63 + strIn = True
64 + strCh = line[i:i+3]
65 + elif line[i:i+3] == strCh:
66 + strIn = False
67 +
68 + i += 2
69 + continue
70 +
71 + c = line[i]
72 + i += 1
73 +
74 + if c == '\'' or c == '\"':
75 + if not strIn:
76 + strIn = True
77 + strCh = c
78 + elif c == strCh:
79 + strIn = False
80 + continue
81 +
82 + if strIn:
83 + continue
84 +
85 + result += c
86 +
87 + return result
88 +
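# Quick illustration of remove_string (illustrative only):
# remove_string('print("hello")  # greet') -> 'print()  # greet'
# string literals are blanked out so that keyword/identifier parsing is not confused by their contents.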
89 +def using_multi_string(line, index):
90 + line = line.strip()
91 + for comment in multi_line_comments:
92 + if line.find(comment, index) > 0:
93 + return True
94 + return False
95 +
96 +def remove_unnecessary_comments(lines):
97 + # Warning: cannot reliably detect every multi-line comment, because such a comment is really just a multi-line string.
98 +
99 + #TODO: the multi-line string parser will not work well when more than one string (or comment) is used on the same line.
100 + # ex) a = ''' d ''' + '''
101 + # abc ''' + '''
102 + # x'''
103 +
104 + result = []
105 + multi_line = False
106 + multi_string = False
107 + strCh = None
108 +
109 + for line in lines:
110 + find_str_index = 0
111 + if multi_string:
112 + if strCh in line:
113 + find_str_index = line.find(strCh) + 3
114 + multi_string = False
115 + strCh = None
116 +
117 + result.append(line)
118 + continue
119 +
120 + if multi_line: # parsing multi-line comments
121 + if strCh in line:
122 + multi_line = False
123 + strCh = None
124 + continue
125 +
126 + if using_multi_string(line, find_str_index):
127 + i1 = line.find(multi_line_comments[0])
128 + i2 = line.find(multi_line_comments[1])
129 +
130 + if i1 < 0:
131 + i1 = len(line) + 1
132 + if i2 < 0:
133 + i2 = len(line) + 1
134 +
135 + if i1 < i2:
136 + strCh = multi_line_comments[0]
137 + else:
138 + strCh = multi_line_comments[1]
139 +
140 + result.append(line)
141 + if line.count(strCh) % 2 != 0:
142 + multi_string = True
143 + continue
144 +
145 + code = line.strip()
146 +
147 + if code[:3] in multi_line_comments: # detect in-out of multi-line comments
148 + if code.count(code[:3]) % 2 != 0: # comment count in odd numbers (start or end of comment is in the line)
149 + multi_line = True
150 + strCh = code[:3]
151 + continue
152 +
153 + comment_index = line.find('#')
154 + if comment_index >= 0: # one line comment found
155 + line = line[:comment_index]
156 + line = line.rstrip() # remove rightmost spaces
157 +
158 + if len(line) == 0: # no code in this line
159 + continue
160 +
161 + result.append(line) # add to results
162 +
163 + return result
164 +
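# Quick illustration (illustrative only):
# remove_unnecessary_comments(['x = 1  # set x\n', '\n', '# standalone comment\n']) -> ['x = 1']
# the inline comment, the blank line and the comment-only line are all dropped.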
165 +def create_block_from_line(line):
166 + _line = remove_string(line)
167 + _line = _line.strip()
168 +
169 + if '@' in _line:
170 + return Block('TYPE_FACTORY', line)
171 +
172 + keywords = _line.split(' ')
173 +
174 + for key in keywords:
175 + if key in TYPE_IMPORT:
176 + return Block('TYPE_IMPORT', line)
177 +
178 + if key in TYPE_CLASS:
179 + return Block('TYPE_CLASS', line)
180 +
181 + if key in TYPE_DEF:
182 + return Block('TYPE_DEF', line)
183 +
184 + if key in TYPE_CONDITION:
185 + return Block('TYPE_CONDITION', line)
186 +
187 + return Block('TYPE_NORMAL', line)
188 +
189 +def create_dictionary(arr): # create index dictionary for str array
190 + ret = dict()
191 +
192 + key = 0
193 + for name in arr:
194 + ret[name] = key
195 + key += 1
196 +
197 + return ret
198 +
199 +def create_mapper(L): # create mapping array to match each index in range L
200 + arr = list(range(L))
201 + random.shuffle(arr)
202 + ret = arr.copy()
203 +
204 + for i in range(L):
205 + ret[i] = arr[i]
206 +
207 + return ret
208 +
209 +def read_file(path):
210 + f = open(path, 'r', encoding='utf8')
211 + ret = f.readlines()
212 + f.close()
213 + return ret
214 +
215 +def write_file(path, lines):
216 + f = open(path, 'w', encoding='utf8')
217 +
218 + for line in lines:
219 + if '\n' in line:
220 + f.write(line)
221 + else:
222 + f.write(line + '\n')
223 + f.close()
224 +
225 +def write_block(path, block):
226 + f = open(path, 'w', encoding='utf8')
227 + f.write(str(block))
228 + f.close()
229 +
230 +def shuffle_block(block):
231 + if block.blockType != 'TYPE_CLASS' and block.blockType != 'TYPE_ROOT':
232 + return
233 +
234 + for b in block.blocks:
235 + shuffle_block(b)
236 +
237 + random.shuffle(block.blocks)
238 +
239 +def detect_multi_string(line, stack):
240 + L = len(line)
241 +
242 + for i in range(L):
243 + if i + 3 > L:
244 + break
245 +
246 + s = line[i:i+3]
247 + if s in multi_line_comments:
248 + if len(stack) > 0 and stack[-1] == s:
249 + stack.pop()
250 + elif len(stack) == 0:
251 + stack.append(s)
252 + return len(stack) > 0
253 +
254 +def detect_parenthesis(line, stack):
255 + line = remove_string(line)
256 +
257 + for c in line:
258 + if c == '(':
259 + stack.append(1)
260 + elif c == ')':
261 + stack.pop()
262 +
263 + if len(stack) > 0:
264 + print(line)
265 + return len(stack) > 0
266 +
267 +def detect_multi_line_code(line):
268 + line = line.rstrip()
269 + return len(line) > 0 and line[-1] == '\\'
270 +
271 +def search_keyword(path, keyword, fast_detect=False): # detect just key string is included in the line if fast_detect is True
272 + files = [f for f in readdir(path) if is_extension(f, 'py')]
273 + result = list()
274 +
275 + for p in files:
276 + lines = read_file(p)
277 + lines = remove_unnecessary_comments(lines)
278 +
279 + for line in lines:
280 +
281 + if fast_detect:
282 + if keyword in line:
283 + result.append(line)
284 + continue
285 +
286 + x = ''
287 + for c in line:
288 + if re.match('[A-Za-z_@0-9]', c):
289 + x += c
290 + else:
291 + x += ' '
292 +
293 + keywords = x.split(' ')
294 + if keyword in keywords:
295 + result.append(line)
296 +
297 + return result
...\ No newline at end of file ...\ No newline at end of file
1 +import os
2 +
3 +MAX_SEQ_LENGTH = 384
4 +BATCH_SIZE = 64
5 +EPOCHS = 50
6 +
7 +BASE_OUTPUT = "output/siamese"
8 +
9 +DATASET_PATH = "data/pair_dataset.npz" #path for generated pair dataset
10 +VECTOR_PATH = "data/vectors.npz" #path for feature vectors from code dataset
11 +EMBEDDING_PATH = "data/embedding.npz" #path for embedding vector
12 +MODEL_PATH = os.path.sep.join([BASE_OUTPUT, "siamese_model"])
13 +PLOT_PATH = os.path.sep.join([BASE_OUTPUT, "plot.png"])
...\ No newline at end of file ...\ No newline at end of file
1 +import numpy as np
2 +import random
3 +import pandas as pd
4 +from keras.preprocessing.text import Tokenizer
5 +from utils import *
6 +
7 +def save_dataset(path, pairData, pairLabels, compressed=True):
8 + if compressed:
9 + np.savez_compressed(path, pairData=pairData, pairLabels=pairLabels)
10 + else:
11 + np.savez(path, pairData=pairData, pairLabels=pairLabels)
12 +
13 +def load_dataset(path):
14 + data = np.load(path, allow_pickle=True)
15 + return (data['pairData'], data['pairLabels'])
16 +
17 +def make_dataset_small(path): # couldn't make the dataset for shuffled/merged/obfuscated, as memory ran out.
18 + vecs = np.load(path, allow_pickle=True)['vecs']
19 +
20 + pairData = []
21 + pairLabels = [] # 1 for plagiarism
22 +
23 +    # positive pair: each vector paired with itself (label 1)
24 + for i in range(len(vecs)):
25 + currentData = vecs[i]
26 +
27 + pairData.append([currentData, currentData])
28 + pairLabels.append([1])
29 +
30 + j = i
31 + while j == i:
32 + j = random.randint(0, len(vecs) - 1)
33 +
34 + pairData.append([currentData, vecs[j]])
35 +        pairLabels.append([0])  # negative pair: paired with a random different file
36 +
37 + return (np.array(pairData), np.array(pairLabels))
38 +
39 +def load_embedding(path):
40 + data = np.load(path, allow_pickle=True)
41 + return (data['vocab_size'], data['embedding_matrix'])
\ No newline at end of file
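A minimal round-trip sketch of save_dataset / load_dataset with toy arrays; the path and shapes here are invented for illustration:

    import numpy as np

    pairData = np.array([[[0.1, 0.2], [0.1, 0.2]],
                         [[0.1, 0.2], [0.9, 0.8]]])   # two pairs of 2-dim toy "vectors"
    pairLabels = np.array([[1], [0]])                 # 1 = plagiarism pair, 0 = unrelated pair
    save_dataset('data/toy_pairs.npz', pairData, pairLabels)
    pd, pl = load_dataset('data/toy_pairs.npz')
    assert pd.shape == (2, 2, 2) and pl.shape == (2, 1)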
1 +import re
2 +from utils import remove_string
3 +
4 +def parse_keywords(line):
5 + line = line.strip()
6 + line = remove_string(line)
7 + result = ''
8 +
9 + for c in line:
10 + if re.match('[A-Za-z_@0-9]', c):
11 + result += c
12 + else:
13 + result += ' '
14 +
15 + return result.split(' ')
\ No newline at end of file
1 +from tensorflow.python.keras import backend as K
2 +from tensorflow.keras.models import Model
3 +from tensorflow.keras.layers import Input
4 +from tensorflow.keras.layers import Layer
5 +from tensorflow.keras.layers import LSTM
6 +from tensorflow.keras.layers import Embedding
7 +from tensorflow.python.keras.layers.wrappers import Bidirectional
8 +from tensorflow.keras.models import Sequential
9 +from tensorflow.keras.optimizers import Adam
10 +
11 +class ManDist(Layer):
12 + def __init__(self, **kwargs):
13 + self.result = None
14 + super(ManDist, self).__init__(**kwargs)
15 +
16 + def build(self, input_shape):
17 + super(ManDist, self).build(input_shape)
18 +
19 + def call(self, x, **kwargs):
20 + self.result = K.exp(-K.sum(K.abs(x[0] - x[1]), axis=1, keepdims=True))
21 + return self.result
22 +
23 +    def compute_output_shape(self, input_shape):  # Keras passes the input shape to this hook
24 + return K.int_shape(self.result)
25 +
26 +def build_siamese_model(embedding_matrix, embeddingDim, max_sequence_length=384, number_lstm_units=50, rate_drop_lstm=0.01):
27 +
28 + x = Sequential()
29 + x.add(Embedding(len(embedding_matrix), embeddingDim, weights=[embedding_matrix], input_shape=(max_sequence_length,), trainable=False))
30 + x.add(LSTM(number_lstm_units, dropout=rate_drop_lstm, return_sequences=True, activation='softmax'))
31 +
32 + input_1 = Input(shape=(max_sequence_length,), dtype='int32')
33 + input_2 = Input(shape=(max_sequence_length,), dtype='int32')
34 +
35 + distance = ManDist()([x(input_1), x(input_2)])
36 + model = Model(inputs=[input_1, input_2], outputs=[distance])
37 + model.compile(loss='mean_squared_error', optimizer=Adam(learning_rate=0.001), metrics=['accuracy'])
38 +
39 + return model
\ No newline at end of file
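For reference, a small NumPy sketch of the quantity the ManDist layer computes, exp of the negative Manhattan (L1) distance between the two encodings: identical vectors give 1.0 and the value decays toward 0 as they diverge. The vectors below are arbitrary toy values:

    import numpy as np

    def manhattan_similarity(a, b):
        # exp of the negative L1 distance, as computed in the ManDist layer
        return np.exp(-np.sum(np.abs(a - b), axis=-1))

    a = np.array([0.1, 0.2, 0.3])
    b = np.array([0.1, 0.2, 0.5])
    print(manhattan_similarity(a, a))  # 1.0
    print(manhattan_similarity(a, b))  # ~0.82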
1 +import config
2 +from tensorflow.keras.models import load_model
3 +from gensim.models import KeyedVectors
4 +from file_parser import parse_keywords
5 +import tensorflow as tf
6 +from utils import *
7 +import random
8 +import numpy as np
9 +
10 +def avg_feature_vector(text, model, num_features, index2word_set):
11 + words = parse_keywords(text)
12 + feature_vec = np.zeros((num_features,), dtype='float32')
13 + n_words = 0
14 + for word in words:
15 + if word in index2word_set:
16 + n_words += 1
17 + feature_vec = np.add(feature_vec, model[word])
18 + if (n_words > 0):
19 + feature_vec = np.divide(feature_vec, n_words)
20 + return feature_vec
21 +
22 +def compare(c2v_model, model, dir1, dir2):
23 + files = [f for f in readdir(dir1) if is_extension(f, 'py')]
24 +
25 + plt.ylabel('cos_sim')
26 + m = 10
27 + Mx = 0
28 + idx = 0
29 + L = len(files)
30 + data = []
31 + index2word_set = set(c2v_model.index_to_key)
32 +
33 + for f in files:
34 + print(idx,"/",L)
35 + f2 = dir2 + f.split(dir1)[1]
36 +
37 + text1 = readAll(f)
38 + text2 = readAll(f2)
39 +
40 + input1 = avg_feature_vector(text1, c2v_model, 384, index2word_set)
41 + input2 = avg_feature_vector(text2, c2v_model, 384, index2word_set)
42 +
43 +        data.append([input1, input2])
44 +        idx += 1
45 +
46 +    result = model.predict([np.array([d[0] for d in data]), np.array([d[1] for d in data])])  # the siamese model expects two input arrays
47 + print(result)
48 +
49 +vectors_text_path = 'data/targets.txt'
50 +c2v_model = KeyedVectors.load_word2vec_format(vectors_text_path, binary=False)
51 +model = load_model(config.MODEL_PATH)
52 +
53 +# Usage
54 +# compare(c2v_model, model, 'data/refined', 'data/shuffled')
\ No newline at end of file
1 +import config
2 +from dataset import load_dataset
3 +from tensorflow.keras.models import load_model
4 +import tensorflow as tf
5 +
6 +pairData, pairLabels = load_dataset(config.DATASET_PATH)
7 +print("Loaded Dataset")
8 +
9 +X1 = pairData[:, 0].tolist()
10 +X2 = pairData[:, 1].tolist()
11 +Label = pairLabels[:].tolist()
12 +
13 +X1 = tf.convert_to_tensor(X1)
14 +X2 = tf.convert_to_tensor(X2)
15 +Label = tf.convert_to_tensor(Label)
16 +
17 +model = load_model(config.MODEL_PATH)
18 +
19 +result = model.evaluate([X1, X2], Label, batch_size=64)
20 +print("test loss, test acc:", result)
\ No newline at end of file
2 +from utils import plot_training
3 +import config
4 +import os
5 +import numpy as np
6 +import random
7 +import tensorflow as tf
8 +from dataset import load_dataset, load_embedding, make_dataset_small, save_dataset
9 +from model import build_siamese_model
10 +from tensorflow.keras.models import load_model
11 +from tensorflow.keras.callbacks import Callback
12 +
13 +# load dataset
14 +if os.path.exists(config.DATASET_PATH):
15 + pairData, pairLabels = load_dataset(config.DATASET_PATH)
16 + print("Loaded Dataset")
17 +else:
18 + print("Generating Dataset...")
19 + pairData, pairLabels = make_dataset_small(config.VECTOR_PATH)
20 + save_dataset(config.DATASET_PATH, pairData, pairLabels)
21 + print("Saved Dataset")
22 +
23 +# build model
24 +
25 +if not os.path.exists(config.MODEL_PATH):
26 + print("Loading Embedding Vectors...")
27 + vocab_size, embedding_matrix = load_embedding(config.EMBEDDING_PATH)
28 + print("Building Models...")
29 + model = build_siamese_model(embedding_matrix, 384)
30 +else:
31 + model = load_model(config.MODEL_PATH)
32 +
33 +# train model
34 +
35 +X1 = pairData[:, 0].tolist()
36 +X2 = pairData[:, 1].tolist()
37 +Label = pairLabels[:].tolist()
38 +
39 +X1 = tf.convert_to_tensor(X1)
40 +X2 = tf.convert_to_tensor(X2)
41 +Label = tf.convert_to_tensor(Label)
42 +
43 +Length = int(len(X1) * 0.7)
44 +trainX1, testX1 = X1[:Length], X1[Length:]
45 +trainX2, testX2 = X2[:Length], X2[Length:]
46 +trainY, testY = Label[:Length], Label[Length:]
47 +
48 +print("Training Model...")
49 +
50 +history = model.fit([trainX1, trainX2], trainY, batch_size=config.BATCH_SIZE, epochs=config.EPOCHS,
51 + validation_data=([testX1, testX2], testY))
52 +
53 +
54 +print("Saving Model...")
55 +model.save(config.MODEL_PATH)
56 +print("Saved Model")
57 +
58 +plot_training(history, config.PLOT_PATH)
\ No newline at end of file
1 +import os
2 +import re
3 +import matplotlib.pyplot as plt
4 +
5 +multi_line_comments = ["'''", '"""']
6 +
7 +def remove_string(line):
8 + strIn = False
9 + strCh = None
10 + result = ''
11 + i = 0
12 + L = len(line)
13 +
14 + while i < L:
15 +        if i + 3 <= L:  # allow a triple quote at the very end of the line
16 + if line[i:i+3] in multi_line_comments:
17 + if not strIn:
18 + strIn = True
19 + strCh = line[i:i+3]
20 + elif line[i:i+3] == strCh:
21 + strIn = False
22 +
23 +                i += 3  # skip the full 3-character quote
24 + continue
25 +
26 + c = line[i]
27 + i += 1
28 +
29 + if c == '\'' or c == '\"':
30 + if not strIn:
31 + strIn = True
32 + strCh = c
33 + elif c == strCh:
34 + strIn = False
35 + continue
36 +
37 + if strIn:
38 + continue
39 +
40 + result += c
41 +
42 + return result
43 +
44 +def is_extension(f, ext):
45 + return os.path.splitext(f)[1][1:] == ext
46 +
47 +def _readdir_r(dirpath): # list the immediate entries of a directory (helper for readdir)
48 + ret = []
49 + for f in os.listdir(dirpath):
50 + ret.append(os.path.join(dirpath, f))
51 +
52 + return ret
53 +
54 +def readdir(path): # recursively collect all file paths under the directory
55 + pathList = [path]
56 + result = []
57 + i = 0
58 +
59 + while i < len(pathList):
60 + f = pathList[i]
61 + if os.path.isdir(f):
62 + pathList += _readdir_r(f)
63 + else:
64 + result.append(f)
65 +
66 + i += 1
67 +
68 + return result
69 +
70 +def readAll(path):
71 + f = open(path, 'r', encoding='utf8')
72 + ret = f.read()
73 + f.close()
74 + return ret
75 +
76 +def readLines(path):
77 + f = open(path, 'r', encoding='utf8')
78 + ret = f.readlines()
79 + f.close()
80 + return ret
81 +
82 +def plot_training(H, plotPath):
83 + plt.style.use("ggplot")
84 + plt.figure()
85 + plt.plot(H.history["loss"], label="train_loss")
86 + plt.plot(H.history["val_loss"], label="val_loss")
87 + plt.plot(H.history["accuracy"], label="train_acc")
88 + plt.plot(H.history["val_accuracy"], label="val_acc")
89 + plt.title("Training Loss and Accuracy")
90 + plt.xlabel("Epoch #")
91 + plt.ylabel("Loss/Accuracy")
92 + plt.legend(loc="lower left")
93 + plt.savefig(plotPath)
\ No newline at end of file
1 +from gensim.models import KeyedVectors
2 +import text2vec
3 +import random
4 +from utils import *
5 +import matplotlib.pyplot as plt
6 +
7 +vectors_text_path = 'data/targets.txt' # w2v output file from model
8 +model = KeyedVectors.load_word2vec_format(vectors_text_path, binary=False)
9 +
10 +def compare(dir1, dir2):
11 + files = [f for f in readdir(dir1) if is_extension(f, 'py')]
12 +
13 + plt.ylabel('cos_sim')
14 + m = 10
15 + Mx = 0
16 + idx = 0
17 + L = len(files)
18 +
19 + for f in files:
20 + print(idx,"/",L)
21 + f2 = dir2 + f.split(dir1)[1]
22 +
23 + text1 = readAll(f)
24 + text2 = readAll(f2)
25 +
26 + similarity = text2vec.get_similarity(text1, text2, model, 384)
27 + m = min(m, similarity)
28 + Mx = max(Mx, similarity)
29 + plt.plot(idx, similarity, 'r.')
30 + idx += 1
31 +
32 + print("min:", m, "max:", Mx)
33 + plt.show()
34 +
35 +def compare2(path): # for merged dataset
36 + pairs = read_file(path + '/log.txt') # log file format: path_merged path_source1 path_source2
37 +
38 + plt.ylabel('cos_sim')
39 + m = 10
40 + Mx = 0
41 + idx = 0
42 + L = len(pairs)
43 + s1 = []
44 + s2 = []
45 +
46 + for p in pairs:
47 + print(idx,"/",L)
48 + arr = p.split(' ')
49 + C = path + '/' + arr[0].strip()
50 + A = arr[1].strip()
51 + B = arr[2].strip()
52 +
53 + text_A = readAll(A)
54 + text_B = readAll(B)
55 + text_C = readAll(C)
56 +
57 + similarity = text2vec.get_similarity(text_A, text_C, model, 384)
58 + m = min(m, similarity)
59 + Mx = max(Mx, similarity)
60 + s1.append(similarity)
61 +
62 + similarity = text2vec.get_similarity(text_B, text_C, model, 384)
63 + m = min(m, similarity)
64 + Mx = max(Mx, similarity)
65 + s2.append(similarity)
66 + idx += 1
67 +
68 + print("min:", m, "max:", Mx)
69 + plt.plot(s1, 'r.')
70 + plt.waitforbuttonpress()
71 +
72 + plt.cla()
73 + plt.plot(s2, 'b.')
74 + plt.show()
75 +
76 +def compare3(dir): # compare random pairs from the original dataset (beware of long processing time)
77 + files = [f for f in readdir(dir) if is_extension(f, 'py')]
78 +
79 + plt.ylabel('cos_sim')
80 + m = 10
81 + Mx = 0
82 + idx = 0
83 + L = len(files)
84 + data = []
85 +
86 + for f in files:
87 + print(idx,"/",L)
88 +
89 + text = readAll(f)
90 + data.append(text)
91 + idx += 1
92 +
93 + for i in range(L):
94 + print(i)
95 + j = i
96 + if i == 0:
97 + continue
98 + while j == i:
99 + j = random.choice(list(range(i)))
100 +
101 + similarity = text2vec.get_similarity(data[i], data[j], model, 384)
102 + m = min(m, similarity)
103 + Mx = max(Mx, similarity)
104 + plt.plot(i, similarity, 'r.')
105 +
106 + print("min:", m, "max:", Mx)
107 + plt.show()
108 +
109 +# Usage
110 +# compare('data/refined', 'data/obfuscated2')
111 +# compare2('data/merged')
112 +# compare3('data/refined')
\ No newline at end of file
1 +import re
2 +from utils import remove_string
3 +
4 +def parse_keywords(line):
5 + line = line.strip()
6 + line = remove_string(line)
7 + result = ''
8 +
9 + for c in line:
10 + if re.match('[A-Za-z_@0-9]', c):
11 + result += c
12 + else:
13 + result += ' '
14 +
15 + return result.split(' ')
\ No newline at end of file
1 +from file_parser import parse_keywords
2 +import numpy as np
3 +from scipy import spatial
4 +
5 +def avg_feature_vector(text, model, num_features, index2word_set):
6 + words = parse_keywords(text)
7 + feature_vec = np.zeros((num_features, ), dtype='float32')
8 + n_words = 0
9 + for word in words:
10 + if word in index2word_set:
11 + n_words += 1
12 + feature_vec = np.add(feature_vec, model[word])
13 + if (n_words > 0):
14 + feature_vec = np.divide(feature_vec, n_words)
15 + return feature_vec
16 +
17 +def get_similarity(text1, text2, model, num_features):
18 + index2word_set = set(model.index_to_key)
19 + s1 = avg_feature_vector(text1, model, num_features, index2word_set)
20 + s2 = avg_feature_vector(text2, model, num_features, index2word_set)
21 + return abs(1 - spatial.distance.cosine(s1, s2))
\ No newline at end of file
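A small check of the value get_similarity returns: the absolute cosine similarity between the two averaged feature vectors. The toy vectors below are invented for illustration:

    import numpy as np
    from scipy import spatial

    s1 = np.array([1.0, 0.0])
    s2 = np.array([0.6, 0.8])
    cos_sim = 1 - spatial.distance.cosine(s1, s2)  # cosine similarity = 0.6
    print(abs(cos_sim))                            # same value get_similarity's formula reports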
1 +import os
2 +
3 +multi_line_comments = ["'''", '"""']
4 +
5 +def remove_string(line):
6 + strIn = False
7 + strCh = None
8 + result = ''
9 + i = 0
10 + L = len(line)
11 +
12 + while i < L:
13 +        if i + 3 <= L:  # allow a triple quote at the very end of the line
14 + if line[i:i+3] in multi_line_comments:
15 + if not strIn:
16 + strIn = True
17 + strCh = line[i:i+3]
18 + elif line[i:i+3] == strCh:
19 + strIn = False
20 +
21 +                i += 3  # skip the full 3-character quote
22 + continue
23 +
24 + c = line[i]
25 + i += 1
26 +
27 + if c == '\'' or c == '\"':
28 + if not strIn:
29 + strIn = True
30 + strCh = c
31 + elif c == strCh:
32 + strIn = False
33 + continue
34 +
35 + if strIn:
36 + continue
37 +
38 + result += c
39 +
40 + return result
41 +
42 +def using_multi_string(line, index):
43 + line = line.strip()
44 + for comment in multi_line_comments:
45 + if line.find(comment, index) > 0:
46 + return True
47 + return False
48 +
49 +def remove_unnecessary_comments(lines):
50 +    # Warning: cannot reliably detect every multi-line comment, because a multi-line comment is really just a multi-line string.
51 +
52 +    # TODO: the multi-line string parser does not handle lines that use more than one string (or comment) at once.
53 + # ex) a = ''' d ''' + '''
54 + # abc ''' + '''
55 + # x'''
56 +
57 + result = []
58 + multi_line = False
59 + multi_string = False
60 + strCh = None
61 +
62 + for line in lines:
63 + find_str_index = 0
64 + if multi_string:
65 + if strCh in line:
66 + find_str_index = line.find(strCh) + 3
67 + multi_string = False
68 + strCh = None
69 +
70 + result.append(line)
71 + continue
72 +
73 + if multi_line: # parsing multi-line comments
74 + if strCh in line:
75 + multi_line = False
76 + strCh = None
77 + continue
78 +
79 + if using_multi_string(line, find_str_index):
80 + i1 = line.find(multi_line_comments[0])
81 + i2 = line.find(multi_line_comments[1])
82 +
83 + if i1 < 0:
84 + i1 = len(line) + 1
85 + if i2 < 0:
86 + i2 = len(line) + 1
87 +
88 + if i1 < i2:
89 + strCh = multi_line_comments[0]
90 + else:
91 + strCh = multi_line_comments[1]
92 +
93 + result.append(line)
94 + if line.count(strCh) % 2 != 0:
95 + multi_string = True
96 + continue
97 +
98 + code = line.strip()
99 +
100 +        if code[:3] in multi_line_comments: # detect entering or leaving a multi-line comment
101 +            if code.count(code[:3]) % 2 != 0: # an odd number of triple quotes means a comment starts or ends on this line
102 + multi_line = True
103 + strCh = code[:3]
104 + continue
105 +
106 + comment_index = line.find('#')
107 + if comment_index >= 0: # one line comment found
108 + line = line[:comment_index]
109 + line = line.rstrip() # remove rightmost spaces
110 +
111 + if len(line) == 0: # no code in this line
112 + continue
113 +
114 + result.append(line) # add to results
115 +
116 + return result
117 +
118 +def is_extension(f, ext):
119 + return os.path.splitext(f)[1][1:] == ext
120 +
121 +def _readdir_r(dirpath): # list the immediate entries of a directory (helper for readdir)
122 + ret = []
123 + for f in os.listdir(dirpath):
124 + ret.append(os.path.join(dirpath, f))
125 +
126 + return ret
127 +
128 +def readdir(path): # recursively collect all file paths under the directory
129 + pathList = [path]
130 + result = []
131 + i = 0
132 +
133 + while i < len(pathList):
134 + f = pathList[i]
135 + if os.path.isdir(f):
136 + pathList += _readdir_r(f)
137 + else:
138 + result.append(f)
139 +
140 + i += 1
141 +
142 + return result
143 +
144 +def read_file(path):
145 + f = open(path, 'r', encoding='utf8')
146 + ret = f.readlines()
147 + f.close()
148 + return ret
149 +
150 +def write_file(path, lines):
151 + f = open(path, 'w', encoding='utf8')
152 +
153 + for line in lines:
154 + if '\n' in line:
155 + f.write(line)
156 + else:
157 + f.write(line + '\n')
158 + f.close()
159 +
160 +def readAll(path):
161 + f = open(path, 'r', encoding='utf8')
162 + ret = f.read()
163 + f.close()
164 + return ret
\ No newline at end of file