Showing 37 changed files with 3676 additions and 0 deletions
code/code2vec/code2vec.py
0 → 100644
1 | +from vocabularies import VocabType | ||
2 | +from config import Config | ||
3 | +from interactive_predict import InteractivePredictor | ||
4 | +from model_base import Code2VecModelBase | ||
5 | + | ||
6 | + | ||
7 | +def load_model_dynamically(config: Config) -> Code2VecModelBase: | ||
8 | + assert config.DL_FRAMEWORK in {'tensorflow', 'keras'} | ||
9 | + if config.DL_FRAMEWORK == 'tensorflow': | ||
10 | + from tensorflow_model import Code2VecModel | ||
11 | + elif config.DL_FRAMEWORK == 'keras': | ||
12 | + from keras_model import Code2VecModel | ||
13 | + return Code2VecModel(config) | ||
14 | + | ||
15 | + | ||
16 | +if __name__ == '__main__': | ||
17 | + config = Config(set_defaults=True, load_from_args=True, verify=True) | ||
18 | + | ||
19 | + model = load_model_dynamically(config) | ||
20 | + | ||
21 | + if config.is_training: | ||
22 | + model.train() | ||
23 | + if config.SAVE_W2V is not None: | ||
24 | + model.save_word2vec_format(config.SAVE_W2V, VocabType.Token) | ||
25 | + config.log('Origin word vectors saved in word2vec text format in: %s' % config.SAVE_W2V) | ||
26 | + if config.SAVE_T2V is not None: | ||
27 | + model.save_word2vec_format(config.SAVE_T2V, VocabType.Target) | ||
28 | + config.log('Target word vectors saved in word2vec text format in: %s' % config.SAVE_T2V) | ||
29 | + if (config.is_testing and not config.is_training) or config.RELEASE: | ||
30 | + eval_results = model.evaluate() | ||
31 | + if eval_results is not None: | ||
32 | + config.log( | ||
33 | + str(eval_results).replace('topk', 'top{}'.format(config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION))) | ||
34 | + if config.PREDICT: | ||
35 | + predictor = InteractivePredictor(config, model) | ||
36 | + predictor.predict() | ||
37 | + model.close_session() |
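Note: the entry point above is driven entirely by Config flags. A minimal sketch of exercising the same flow programmatically instead of via the CLI (assumes the modules in this PR are importable; the dataset prefix and save path are hypothetical placeholders):

from config import Config
from code2vec import load_model_dynamically

# Build a Config by hand instead of parsing command-line arguments.
config = Config(set_defaults=True, load_from_args=False, verify=False)
config.TRAIN_DATA_PATH_PREFIX = 'my_dataset/my_dataset'    # hypothetical preprocessed-data prefix
config.MODEL_SAVE_PATH = 'models/my_model/saved_model'     # hypothetical save location
config.DL_FRAMEWORK = 'tensorflow'

model = load_model_dynamically(config)   # selects tensorflow_model.Code2VecModel here
model.train()
model.close_session()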
code/code2vec/common.py
0 → 100644
1 | +import re | ||
2 | +import numpy as np | ||
3 | +import tensorflow as tf | ||
4 | +from itertools import takewhile, repeat | ||
5 | +from typing import List, Optional, Tuple, Iterable | ||
6 | +from datetime import datetime | ||
7 | +from collections import OrderedDict | ||
8 | + | ||
9 | + | ||
10 | +class common: | ||
11 | + | ||
12 | + @staticmethod | ||
13 | + def normalize_word(word): | ||
14 | + stripped = re.sub(r'[^a-zA-Z]', '', word) | ||
15 | + if len(stripped) == 0: | ||
16 | + return word.lower() | ||
17 | + else: | ||
18 | + return stripped.lower() | ||
19 | + | ||
20 | + @staticmethod | ||
21 | + def _load_vocab_from_histogram(path, min_count=0, start_from=0, return_counts=False): | ||
22 | + with open(path, 'r') as file: | ||
23 | + word_to_index = {} | ||
24 | + index_to_word = {} | ||
25 | + word_to_count = {} | ||
26 | + next_index = start_from | ||
27 | + for line in file: | ||
28 | + line_values = line.rstrip().split(' ') | ||
29 | + if len(line_values) != 2: | ||
30 | + continue | ||
31 | + word = line_values[0] | ||
32 | + count = int(line_values[1]) | ||
33 | + if count < min_count: | ||
34 | + continue | ||
35 | + if word in word_to_index: | ||
36 | + continue | ||
37 | + word_to_index[word] = next_index | ||
38 | + index_to_word[next_index] = word | ||
39 | + word_to_count[word] = count | ||
40 | + next_index += 1 | ||
41 | + result = word_to_index, index_to_word, next_index - start_from | ||
42 | + if return_counts: | ||
43 | + result = (*result, word_to_count) | ||
44 | + return result | ||
45 | + | ||
46 | + @staticmethod | ||
47 | + def load_vocab_from_histogram(path, min_count=0, start_from=0, max_size=None, return_counts=False): | ||
48 | + if max_size is not None: | ||
49 | + word_to_index, index_to_word, next_index, word_to_count = \ | ||
50 | + common._load_vocab_from_histogram(path, min_count, start_from, return_counts=True) | ||
51 | + if next_index <= max_size: | ||
52 | + results = (word_to_index, index_to_word, next_index) | ||
53 | + if return_counts: | ||
54 | + results = (*results, word_to_count) | ||
55 | + return results | ||
56 | + # Take min_count to be one plus the count of the max_size'th word | ||
57 | + min_count = sorted(word_to_count.values(), reverse=True)[max_size] + 1 | ||
58 | + return common._load_vocab_from_histogram(path, min_count, start_from, return_counts) | ||
59 | + | ||
60 | + @staticmethod | ||
61 | + def load_json(json_file): | ||
62 | + data = [] | ||
63 | + with open(json_file, 'r') as file: | ||
64 | + for line in file: | ||
65 | + current_program = common.process_single_json_line(line) | ||
66 | + if current_program is None: | ||
67 | + continue | ||
68 | + for element, scope in current_program.items(): | ||
69 | + data.append((element, scope)) | ||
70 | + return data | ||
71 | + | ||
72 | + @staticmethod | ||
73 | + def load_json_streaming(json_file): | ||
74 | + with open(json_file, 'r') as file: | ||
75 | + for line in file: | ||
76 | + current_program = common.process_single_json_line(line) | ||
77 | + if current_program is None: | ||
78 | + continue | ||
79 | + for element, scope in current_program.items(): | ||
80 | + yield (element, scope) | ||
81 | + | ||
82 | + @staticmethod | ||
83 | + def save_word2vec_file(output_file, index_to_word, vocab_embedding_matrix: np.ndarray): | ||
84 | + assert len(vocab_embedding_matrix.shape) == 2 | ||
85 | + vocab_size, embedding_dimension = vocab_embedding_matrix.shape | ||
86 | + output_file.write('%d %d\n' % (vocab_size, embedding_dimension)) | ||
87 | + for word_idx in range(0, vocab_size): | ||
88 | + assert word_idx in index_to_word | ||
89 | + word_str = index_to_word[word_idx] | ||
90 | + output_file.write(word_str + ' ') | ||
91 | + output_file.write(' '.join(map(str, vocab_embedding_matrix[word_idx])) + '\n') | ||
92 | + | ||
93 | + @staticmethod | ||
94 | + def calculate_max_contexts(file): | ||
95 | + contexts_per_word = common.process_test_input(file) | ||
96 | + return max( | ||
97 | + [max(l, default=0) for l in [[len(contexts) for contexts in prog.values()] for prog in contexts_per_word]], | ||
98 | + default=0) | ||
99 | + | ||
100 | + @staticmethod | ||
101 | + def binary_to_string(binary_string): | ||
102 | + return binary_string.decode("utf-8") | ||
103 | + | ||
104 | + @staticmethod | ||
105 | + def binary_to_string_list(binary_string_list): | ||
106 | + return [common.binary_to_string(w) for w in binary_string_list] | ||
107 | + | ||
108 | + @staticmethod | ||
109 | + def binary_to_string_matrix(binary_string_matrix): | ||
110 | + return [common.binary_to_string_list(l) for l in binary_string_matrix] | ||
111 | + | ||
112 | + @staticmethod | ||
113 | + def load_file_lines(path): | ||
114 | + with open(path, 'r') as f: | ||
115 | + return f.read().splitlines() | ||
116 | + | ||
117 | + @staticmethod | ||
118 | + def split_to_batches(data_lines, batch_size): | ||
119 | + for x in range(0, len(data_lines), batch_size): | ||
120 | + yield data_lines[x:x + batch_size] | ||
121 | + | ||
122 | + @staticmethod | ||
123 | + def legal_method_names_checker(special_words, name): | ||
124 | + return name != special_words.OOV and re.match(r'^[a-zA-Z_|]+[a-zA-Z_]+[a-zA-Z0-9_]+$', name) | ||
125 | + | ||
126 | + @staticmethod | ||
127 | + def filter_impossible_names(special_words, top_words): | ||
128 | + result = list(filter(lambda word: common.legal_method_names_checker(special_words, word), top_words)) | ||
129 | + return result | ||
130 | + | ||
131 | + @staticmethod | ||
132 | + def get_subtokens(word): | ||
133 | + return word.split('|') | ||
134 | + | ||
135 | + @staticmethod | ||
136 | + def parse_prediction_results(raw_prediction_results, unhash_dict, special_words, topk: int = 5) -> List['MethodPredictionResults']: | ||
137 | + prediction_results = [] | ||
138 | + for single_method_prediction in raw_prediction_results: | ||
139 | + current_method_prediction_results = MethodPredictionResults(single_method_prediction.original_name) | ||
140 | + for i, predicted in enumerate(single_method_prediction.topk_predicted_words): | ||
141 | + if predicted == special_words.OOV: | ||
142 | + continue | ||
143 | + suggestion_subtokens = common.get_subtokens(predicted) | ||
144 | + current_method_prediction_results.append_prediction( | ||
145 | + suggestion_subtokens, single_method_prediction.topk_predicted_words_scores[i].item()) | ||
146 | + topk_attention_per_context = [ | ||
147 | + (key, single_method_prediction.attention_per_context[key]) | ||
148 | + for key in sorted(single_method_prediction.attention_per_context, | ||
149 | + key=single_method_prediction.attention_per_context.get, reverse=True) | ||
150 | + ][:topk] | ||
151 | + for context, attention in topk_attention_per_context: | ||
152 | + token1, hashed_path, token2 = context | ||
153 | + if hashed_path in unhash_dict: | ||
154 | + unhashed_path = unhash_dict[hashed_path] | ||
155 | + current_method_prediction_results.append_attention_path(attention.item(), token1=token1, | ||
156 | + path=unhashed_path, token2=token2) | ||
157 | + prediction_results.append(current_method_prediction_results) | ||
158 | + return prediction_results | ||
159 | + | ||
160 | + @staticmethod | ||
161 | + def tf_get_first_true(bool_tensor: tf.Tensor) -> tf.Tensor: | ||
162 | + bool_tensor_as_int32 = tf.cast(bool_tensor, dtype=tf.int32) | ||
163 | + cumsum = tf.cumsum(bool_tensor_as_int32, axis=-1, exclusive=False) | ||
164 | + return tf.logical_and(tf.equal(cumsum, 1), bool_tensor) | ||
165 | + | ||
166 | + @staticmethod | ||
167 | + def count_lines_in_file(file_path: str): | ||
168 | + with open(file_path, 'rb') as f: | ||
169 | + bufgen = takewhile(lambda x: x, (f.raw.read(1024 * 1024) for _ in repeat(None))) | ||
170 | + return sum(buf.count(b'\n') for buf in bufgen) | ||
171 | + | ||
172 | + @staticmethod | ||
173 | + def squeeze_single_batch_dimension_for_np_arrays(arrays): | ||
174 | + assert all(array is None or isinstance(array, np.ndarray) or isinstance(array, tf.Tensor) for array in arrays) | ||
175 | + return tuple( | ||
176 | + None if array is None else np.squeeze(array, axis=0) | ||
177 | + for array in arrays | ||
178 | + ) | ||
179 | + | ||
180 | + @staticmethod | ||
181 | + def get_first_match_word_from_top_predictions(special_words, original_name, top_predicted_words) -> Optional[Tuple[int, str]]: | ||
182 | + normalized_original_name = common.normalize_word(original_name) | ||
183 | + for suggestion_idx, predicted_word in enumerate(common.filter_impossible_names(special_words, top_predicted_words)): | ||
184 | + normalized_possible_suggestion = common.normalize_word(predicted_word) | ||
185 | + if normalized_original_name == normalized_possible_suggestion: | ||
186 | + return suggestion_idx, predicted_word | ||
187 | + return None | ||
188 | + | ||
189 | + @staticmethod | ||
190 | + def now_str(): | ||
191 | + return datetime.now().strftime("%Y%m%d-%H%M%S: ") | ||
192 | + | ||
193 | + @staticmethod | ||
194 | + def chunks(l, n): | ||
195 | + """Yield successive n-sized chunks from l.""" | ||
196 | + for i in range(0, len(l), n): | ||
197 | + yield l[i:i + n] | ||
198 | + | ||
199 | + @staticmethod | ||
200 | + def get_unique_list(lst: Iterable) -> list: | ||
201 | + return list(OrderedDict(((item, 0) for item in lst)).keys()) | ||
202 | + | ||
203 | + | ||
204 | +class MethodPredictionResults: | ||
205 | + def __init__(self, original_name): | ||
206 | + self.original_name = original_name | ||
207 | + self.predictions = list() | ||
208 | + self.attention_paths = list() | ||
209 | + | ||
210 | + def append_prediction(self, name, probability): | ||
211 | + self.predictions.append({'name': name, 'probability': probability}) | ||
212 | + | ||
213 | + def append_attention_path(self, attention_score, token1, path, token2): | ||
214 | + self.attention_paths.append({'score': attention_score, | ||
215 | + 'path': path, | ||
216 | + 'token1': token1, | ||
217 | + 'token2': token2}) |
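Note: the vocabulary helpers above expect a histogram file with one "<word> <count>" pair per line. A small usage sketch (file name and contents are invented):

from common import common

# Toy histogram: four words with their occurrence counts.
with open('toy.histo.txt', 'w') as f:
    f.write('get 120\nset 80\nvalue 15\nfoo 2\n')

# Cap the vocabulary at 3 entries; the rarest word ('foo') is dropped.
word_to_index, index_to_word, vocab_size = common.load_vocab_from_histogram(
    'toy.histo.txt', start_from=1, max_size=3)
print(vocab_size)            # 3
print(word_to_index['get'])  # 1 (indices start at start_from)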
code/code2vec/config.py
0 → 100644
1 | +from math import ceil | ||
2 | +from typing import Optional | ||
3 | +import logging | ||
4 | +from argparse import ArgumentParser | ||
5 | +import sys | ||
6 | +import os | ||
7 | + | ||
8 | + | ||
9 | +class Config: | ||
10 | + @classmethod | ||
11 | + def arguments_parser(cls) -> ArgumentParser: | ||
12 | + parser = ArgumentParser() | ||
13 | + parser.add_argument("-d", "--data", dest="data_path", | ||
14 | + help="path to preprocessed dataset", required=False) | ||
15 | + parser.add_argument("-te", "--test", dest="test_path", | ||
16 | + help="path to test file", metavar="FILE", required=False, default='') | ||
17 | + parser.add_argument("-s", "--save", dest="save_path", | ||
18 | + help="path to save the model file", metavar="FILE", required=False) | ||
19 | + parser.add_argument("-w2v", "--save_word2v", dest="save_w2v", | ||
20 | + help="path to save the tokens embeddings file", metavar="FILE", required=False) | ||
21 | + parser.add_argument("-t2v", "--save_target2v", dest="save_t2v", | ||
22 | + help="path to save the targets embeddings file", metavar="FILE", required=False) | ||
23 | + parser.add_argument("-l", "--load", dest="load_path", | ||
24 | + help="path to load the model from", metavar="FILE", required=False) | ||
25 | + parser.add_argument('--save_w2v', dest='save_w2v', required=False, | ||
26 | + help="save word (token) vectors in word2vec format") | ||
27 | + parser.add_argument('--save_t2v', dest='save_t2v', required=False, | ||
28 | + help="save target vectors in word2vec format") | ||
29 | + parser.add_argument('--export_code_vectors', action='store_true', required=False, | ||
30 | + help="export code vectors for the given examples") | ||
31 | + parser.add_argument('--release', action='store_true', | ||
32 | + help='if specified and loading a trained model, release the loaded model for a lower model ' | ||
33 | + 'size.') | ||
34 | + parser.add_argument('--predict', action='store_true', | ||
35 | + help='execute the interactive prediction shell') | ||
36 | + parser.add_argument("-fw", "--framework", dest="dl_framework", choices=['keras', 'tensorflow'], | ||
37 | + default='tensorflow', help="deep learning framework to use.") | ||
38 | + parser.add_argument("-v", "--verbose", dest="verbose_mode", type=int, required=False, default=1, | ||
39 | + help="verbose mode (should be in {0,1,2}).") | ||
40 | + parser.add_argument("-lp", "--logs-path", dest="logs_path", metavar="FILE", required=False, | ||
41 | + help="path to store logs into. If not given, logs are not saved to a file.") | ||
42 | + parser.add_argument('-tb', '--tensorboard', dest='use_tensorboard', action='store_true', | ||
43 | + help='use tensorboard during training') | ||
44 | + return parser | ||
45 | + | ||
46 | + def set_defaults(self): | ||
47 | + self.NUM_TRAIN_EPOCHS = 20 | ||
48 | + self.SAVE_EVERY_EPOCHS = 1 | ||
49 | + self.TRAIN_BATCH_SIZE = 1024 | ||
50 | + self.TEST_BATCH_SIZE = self.TRAIN_BATCH_SIZE | ||
51 | + self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10 | ||
52 | + self.NUM_BATCHES_TO_LOG_PROGRESS = 100 | ||
53 | + self.NUM_TRAIN_BATCHES_TO_EVALUATE = 1800 | ||
54 | + self.READER_NUM_PARALLEL_BATCHES = 6 | ||
55 | + self.SHUFFLE_BUFFER_SIZE = 10000 | ||
56 | + self.CSV_BUFFER_SIZE = 100 * 1024 * 1024 | ||
57 | + self.MAX_TO_KEEP = 10 | ||
58 | + | ||
59 | + self.MAX_CONTEXTS = 200 | ||
60 | + self.MAX_TOKEN_VOCAB_SIZE = 1301136 | ||
61 | + self.MAX_TARGET_VOCAB_SIZE = 261245 | ||
62 | + self.MAX_PATH_VOCAB_SIZE = 911417 | ||
63 | + self.DEFAULT_EMBEDDINGS_SIZE = 128 | ||
64 | + self.TOKEN_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE | ||
65 | + self.PATH_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE | ||
66 | + self.CODE_VECTOR_SIZE = self.context_vector_size | ||
67 | + self.TARGET_EMBEDDINGS_SIZE = self.CODE_VECTOR_SIZE | ||
68 | + self.DROPOUT_KEEP_RATE = 0.75 | ||
69 | + self.SEPARATE_OOV_AND_PAD = False | ||
70 | + | ||
71 | + def load_from_args(self): | ||
72 | + args = self.arguments_parser().parse_args() | ||
73 | + self.PREDICT = args.predict | ||
74 | + self.MODEL_SAVE_PATH = args.save_path | ||
75 | + self.MODEL_LOAD_PATH = args.load_path | ||
76 | + self.TRAIN_DATA_PATH_PREFIX = args.data_path | ||
77 | + self.TEST_DATA_PATH = args.test_path | ||
78 | + self.RELEASE = args.release | ||
79 | + self.EXPORT_CODE_VECTORS = args.export_code_vectors | ||
80 | + self.SAVE_W2V = args.save_w2v | ||
81 | + self.SAVE_T2V = args.save_t2v | ||
82 | + self.VERBOSE_MODE = args.verbose_mode | ||
83 | + self.LOGS_PATH = args.logs_path | ||
84 | + self.DL_FRAMEWORK = 'tensorflow' if not args.dl_framework else args.dl_framework | ||
85 | + self.USE_TENSORBOARD = args.use_tensorboard | ||
86 | + | ||
87 | + def __init__(self, set_defaults: bool = False, load_from_args: bool = False, verify: bool = False): | ||
88 | + self.NUM_TRAIN_EPOCHS: int = 0 | ||
89 | + self.SAVE_EVERY_EPOCHS: int = 0 | ||
90 | + self.TRAIN_BATCH_SIZE: int = 0 | ||
91 | + self.TEST_BATCH_SIZE: int = 0 | ||
92 | + self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION: int = 0 | ||
93 | + self.NUM_BATCHES_TO_LOG_PROGRESS: int = 0 | ||
94 | + self.NUM_TRAIN_BATCHES_TO_EVALUATE: int = 0 | ||
95 | + self.READER_NUM_PARALLEL_BATCHES: int = 0 | ||
96 | + self.SHUFFLE_BUFFER_SIZE: int = 0 | ||
97 | + self.CSV_BUFFER_SIZE: int = 0 | ||
98 | + self.MAX_TO_KEEP: int = 0 | ||
99 | + | ||
100 | + self.MAX_CONTEXTS: int = 0 | ||
101 | + self.MAX_TOKEN_VOCAB_SIZE: int = 0 | ||
102 | + self.MAX_TARGET_VOCAB_SIZE: int = 0 | ||
103 | + self.MAX_PATH_VOCAB_SIZE: int = 0 | ||
104 | + self.DEFAULT_EMBEDDINGS_SIZE: int = 0 | ||
105 | + self.TOKEN_EMBEDDINGS_SIZE: int = 0 | ||
106 | + self.PATH_EMBEDDINGS_SIZE: int = 0 | ||
107 | + self.CODE_VECTOR_SIZE: int = 0 | ||
108 | + self.TARGET_EMBEDDINGS_SIZE: int = 0 | ||
109 | + self.DROPOUT_KEEP_RATE: float = 0 | ||
110 | + self.SEPARATE_OOV_AND_PAD: bool = False | ||
111 | + | ||
112 | + self.PREDICT: bool = False | ||
113 | + self.MODEL_SAVE_PATH: Optional[str] = None | ||
114 | + self.MODEL_LOAD_PATH: Optional[str] = None | ||
115 | + self.TRAIN_DATA_PATH_PREFIX: Optional[str] = None | ||
116 | + self.TEST_DATA_PATH: Optional[str] = '' | ||
117 | + self.RELEASE: bool = False | ||
118 | + self.EXPORT_CODE_VECTORS: bool = False | ||
119 | + self.SAVE_W2V: Optional[str] = None | ||
120 | + self.SAVE_T2V: Optional[str] = None | ||
121 | + self.VERBOSE_MODE: int = 0 | ||
122 | + self.LOGS_PATH: Optional[str] = None | ||
123 | + self.DL_FRAMEWORK: str = 'tensorflow' | ||
124 | + self.USE_TENSORBOARD: bool = False | ||
125 | + | ||
126 | + self.NUM_TRAIN_EXAMPLES: int = 0 | ||
127 | + self.NUM_TEST_EXAMPLES: int = 0 | ||
128 | + | ||
129 | + self.__logger: Optional[logging.Logger] = None | ||
130 | + | ||
131 | + if set_defaults: | ||
132 | + self.set_defaults() | ||
133 | + if load_from_args: | ||
134 | + self.load_from_args() | ||
135 | + if verify: | ||
136 | + self.verify() | ||
137 | + | ||
138 | + @property | ||
139 | + def context_vector_size(self) -> int: | ||
140 | + return self.PATH_EMBEDDINGS_SIZE + 2 * self.TOKEN_EMBEDDINGS_SIZE | ||
141 | + | ||
142 | + @property | ||
143 | + def is_training(self) -> bool: | ||
144 | + return bool(self.TRAIN_DATA_PATH_PREFIX) | ||
145 | + | ||
146 | + @property | ||
147 | + def is_loading(self) -> bool: | ||
148 | + return bool(self.MODEL_LOAD_PATH) | ||
149 | + | ||
150 | + @property | ||
151 | + def is_saving(self) -> bool: | ||
152 | + return bool(self.MODEL_SAVE_PATH) | ||
153 | + | ||
154 | + @property | ||
155 | + def is_testing(self) -> bool: | ||
156 | + return bool(self.TEST_DATA_PATH) | ||
157 | + | ||
158 | + @property | ||
159 | + def train_steps_per_epoch(self) -> int: | ||
160 | + return ceil(self.NUM_TRAIN_EXAMPLES / self.TRAIN_BATCH_SIZE) if self.TRAIN_BATCH_SIZE else 0 | ||
161 | + | ||
162 | + @property | ||
163 | + def test_steps(self) -> int: | ||
164 | + return ceil(self.NUM_TEST_EXAMPLES / self.TEST_BATCH_SIZE) if self.TEST_BATCH_SIZE else 0 | ||
165 | + | ||
166 | + def data_path(self, is_evaluating: bool = False): | ||
167 | + return self.TEST_DATA_PATH if is_evaluating else self.train_data_path | ||
168 | + | ||
169 | + def batch_size(self, is_evaluating: bool = False): | ||
170 | + return self.TEST_BATCH_SIZE if is_evaluating else self.TRAIN_BATCH_SIZE # take min with NUM_TRAIN_EXAMPLES? | ||
171 | + | ||
172 | + @property | ||
173 | + def train_data_path(self) -> Optional[str]: | ||
174 | + if not self.is_training: | ||
175 | + return None | ||
176 | + return '{}.train.c2v'.format(self.TRAIN_DATA_PATH_PREFIX) | ||
177 | + | ||
178 | + @property | ||
179 | + def word_freq_dict_path(self) -> Optional[str]: | ||
180 | + if not self.is_training: | ||
181 | + return None | ||
182 | + return '{}.dict.c2v'.format(self.TRAIN_DATA_PATH_PREFIX) | ||
183 | + | ||
184 | + @classmethod | ||
185 | + def get_vocabularies_path_from_model_path(cls, model_file_path: str) -> str: | ||
186 | + vocabularies_save_file_name = "dictionaries.bin" | ||
187 | + return '/'.join(model_file_path.split('/')[:-1] + [vocabularies_save_file_name]) | ||
188 | + | ||
189 | + @classmethod | ||
190 | + def get_entire_model_path(cls, model_path: str) -> str: | ||
191 | + return model_path + '__entire-model' | ||
192 | + | ||
193 | + @classmethod | ||
194 | + def get_model_weights_path(cls, model_path: str) -> str: | ||
195 | + return model_path + '__only-weights' | ||
196 | + | ||
197 | + @property | ||
198 | + def model_load_dir(self): | ||
199 | + return '/'.join(self.MODEL_LOAD_PATH.split('/')[:-1]) | ||
200 | + | ||
201 | + @property | ||
202 | + def entire_model_load_path(self) -> Optional[str]: | ||
203 | + if not self.is_loading: | ||
204 | + return None | ||
205 | + return self.get_entire_model_path(self.MODEL_LOAD_PATH) | ||
206 | + | ||
207 | + @property | ||
208 | + def model_weights_load_path(self) -> Optional[str]: | ||
209 | + if not self.is_loading: | ||
210 | + return None | ||
211 | + return self.get_model_weights_path(self.MODEL_LOAD_PATH) | ||
212 | + | ||
213 | + @property | ||
214 | + def entire_model_save_path(self) -> Optional[str]: | ||
215 | + if not self.is_saving: | ||
216 | + return None | ||
217 | + return self.get_entire_model_path(self.MODEL_SAVE_PATH) | ||
218 | + | ||
219 | + @property | ||
220 | + def model_weights_save_path(self) -> Optional[str]: | ||
221 | + if not self.is_saving: | ||
222 | + return None | ||
223 | + return self.get_model_weights_path(self.MODEL_SAVE_PATH) | ||
224 | + | ||
225 | + def verify(self): | ||
226 | + if not self.is_training and not self.is_loading: | ||
227 | + raise ValueError("Must train or load a model.") | ||
228 | + if self.is_loading and not os.path.isdir(self.model_load_dir): | ||
229 | + raise ValueError("Model load dir `{model_load_dir}` does not exist.".format( | ||
230 | + model_load_dir=self.model_load_dir)) | ||
231 | + | ||
232 | + def __iter__(self): | ||
233 | + for attr_name in dir(self): | ||
234 | + if attr_name.startswith("__"): | ||
235 | + continue | ||
236 | + try: | ||
237 | + attr_value = getattr(self, attr_name, None) | ||
238 | + except Exception: | ||
239 | + attr_value = None | ||
240 | + if callable(attr_value): | ||
241 | + continue | ||
242 | + yield attr_name, attr_value | ||
243 | + | ||
244 | + def get_logger(self) -> logging.Logger: | ||
245 | + if self.__logger is None: | ||
246 | + self.__logger = logging.getLogger('code2vec') | ||
247 | + self.__logger.setLevel(logging.INFO) | ||
248 | + self.__logger.handlers = [] | ||
249 | + self.__logger.propagate = 0 | ||
250 | + | ||
251 | + formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') | ||
252 | + | ||
253 | + if self.VERBOSE_MODE >= 1: | ||
254 | + ch = logging.StreamHandler(sys.stdout) | ||
255 | + ch.setLevel(logging.INFO) | ||
256 | + ch.setFormatter(formatter) | ||
257 | + self.__logger.addHandler(ch) | ||
258 | + | ||
259 | + if self.LOGS_PATH: | ||
260 | + fh = logging.FileHandler(self.LOGS_PATH) | ||
261 | + fh.setLevel(logging.INFO) | ||
262 | + fh.setFormatter(formatter) | ||
263 | + self.__logger.addHandler(fh) | ||
264 | + | ||
265 | + return self.__logger | ||
266 | + | ||
267 | + def log(self, msg): | ||
268 | + self.get_logger().info(msg) |
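Note: with the defaults above, the code vector is the concatenation of the source-token, path and target-token embeddings, so its width follows directly from the embedding sizes. A quick sanity check (values taken from set_defaults):

from config import Config

config = Config(set_defaults=True, load_from_args=False, verify=False)
# context_vector_size = PATH_EMBEDDINGS_SIZE + 2 * TOKEN_EMBEDDINGS_SIZE = 128 + 2 * 128
assert config.context_vector_size == 384
assert config.CODE_VECTOR_SIZE == 384 and config.TARGET_EMBEDDINGS_SIZE == 384
# Config.__iter__ exposes the hyper-parameters as (name, value) pairs.
print(dict(config)['TRAIN_BATCH_SIZE'])  # 1024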
code/code2vec/interactive_predict.py
0 → 100644
1 | +import traceback | ||
2 | + | ||
3 | +from common import common | ||
4 | +from py_extractor import PyExtractor | ||
5 | + | ||
6 | +SHOW_TOP_CONTEXTS = 10 | ||
7 | +MAX_PATH_LENGTH = 8 | ||
8 | +MAX_PATH_WIDTH = 2 | ||
9 | +input_filename = 'test.c2v' | ||
10 | + | ||
11 | + | ||
12 | +class InteractivePredictor: | ||
13 | + exit_keywords = ['exit', 'quit', 'q'] | ||
14 | + | ||
15 | + def __init__(self, config, model): | ||
16 | + model.predict([]) | ||
17 | + self.model = model | ||
18 | + self.config = config | ||
19 | + self.path_extractor = PyExtractor(config) | ||
20 | + | ||
21 | + def predict(self): | ||
22 | + print('Starting interactive prediction...') | ||
23 | + while True: | ||
24 | + print('Modify the file: "%s" and press Enter when ready, or type "q" / "quit" / "exit" to exit' % input_filename) | ||
25 | + user_input = input() | ||
26 | + if user_input.lower() in self.exit_keywords: | ||
27 | + print('Exiting...') | ||
28 | + return | ||
29 | + try: | ||
30 | + predict_lines, hash_to_string_dict = self.path_extractor.extract_paths(input_filename) | ||
31 | + except ValueError as e: | ||
32 | + print(e) | ||
33 | + continue | ||
34 | + raw_prediction_results = self.model.predict(predict_lines) | ||
35 | + method_prediction_results = common.parse_prediction_results( | ||
36 | + raw_prediction_results, hash_to_string_dict, | ||
37 | + self.model.vocabs.target_vocab.special_words, topk=SHOW_TOP_CONTEXTS) | ||
38 | + for raw_prediction, method_prediction in zip(raw_prediction_results, method_prediction_results): | ||
39 | + print('Original name:\t' + method_prediction.original_name) | ||
40 | + for name_prob_pair in method_prediction.predictions: | ||
41 | + print('\t(%f) predicted: %s' % (name_prob_pair['probability'], name_prob_pair['name'])) | ||
42 | + print('Attention:') | ||
43 | + for attention_obj in method_prediction.attention_paths: | ||
44 | + print('%f\tcontext: %s,%s,%s' % ( | ||
45 | + attention_obj['score'], attention_obj['token1'], attention_obj['path'], attention_obj['token2'])) | ||
46 | + if self.config.EXPORT_CODE_VECTORS: | ||
47 | + print('Code vector:') | ||
48 | + print(' '.join(map(str, raw_prediction.code_vector))) |
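Note: the predictor re-reads test.c2v through PyExtractor (further down in this PR); each line is expected to be "<method_name> <token,path,token> <token,path,token> ...". A toy line for illustration (all tokens and path strings below are invented placeholders):

# Sketch: write one toy input line in the shape InteractivePredictor / PyExtractor expect.
toy_line = 'get|value self,P_17,value value,P_42,return'
with open('test.c2v', 'w') as f:
    f.write(toy_line + '\n')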
code/code2vec/model_base.py
0 → 100644
1 | +import numpy as np | ||
2 | +import abc | ||
3 | +import os | ||
4 | +from typing import NamedTuple, Optional, List, Dict, Tuple, Iterable | ||
5 | + | ||
6 | +from common import common | ||
7 | +from vocabularies import Code2VecVocabs, VocabType | ||
8 | +from config import Config | ||
9 | + | ||
10 | + | ||
11 | +class ModelEvaluationResults(NamedTuple): | ||
12 | + topk_acc: float | ||
13 | + subtoken_precision: float | ||
14 | + subtoken_recall: float | ||
15 | + subtoken_f1: float | ||
16 | + loss: Optional[float] = None | ||
17 | + | ||
18 | + def __str__(self): | ||
19 | + res_str = 'topk_acc: {topk_acc}, precision: {precision}, recall: {recall}, F1: {f1}'.format( | ||
20 | + topk_acc=self.topk_acc, | ||
21 | + precision=self.subtoken_precision, | ||
22 | + recall=self.subtoken_recall, | ||
23 | + f1=self.subtoken_f1) | ||
24 | + if self.loss is not None: | ||
25 | + res_str = ('loss: {}, '.format(self.loss)) + res_str | ||
26 | + return res_str | ||
27 | + | ||
28 | + | ||
29 | +class ModelPredictionResults(NamedTuple): | ||
30 | + original_name: str | ||
31 | + topk_predicted_words: np.ndarray | ||
32 | + topk_predicted_words_scores: np.ndarray | ||
33 | + attention_per_context: Dict[Tuple[str, str, str], float] | ||
34 | + code_vector: Optional[np.ndarray] = None | ||
35 | + | ||
36 | + | ||
37 | +class Code2VecModelBase(abc.ABC): | ||
38 | + def __init__(self, config: Config): | ||
39 | + self.config = config | ||
40 | + self.config.verify() | ||
41 | + | ||
42 | + self._log_creating_model() | ||
43 | + | ||
44 | + if not config.RELEASE: | ||
45 | + self._init_num_of_examples() | ||
46 | + self._log_model_configuration() | ||
47 | + self.vocabs = Code2VecVocabs(config) | ||
48 | + self.vocabs.target_vocab.get_index_to_word_lookup_table() | ||
49 | + self._load_or_create_inner_model() | ||
50 | + self._initialize() | ||
51 | + | ||
52 | + def _log_creating_model(self): | ||
53 | + self.log('') | ||
54 | + self.log('') | ||
55 | + self.log('---------------------------------------------------------------------') | ||
56 | + self.log('---------------------------------------------------------------------') | ||
57 | + self.log('---------------------- Creating code2vec model ----------------------') | ||
58 | + self.log('---------------------------------------------------------------------') | ||
59 | + self.log('---------------------------------------------------------------------') | ||
60 | + | ||
61 | + def _log_model_configuration(self): | ||
62 | + self.log('---------------------------------------------------------------------') | ||
63 | + self.log('----------------- Configuration - Hyper Parameters ------------------') | ||
64 | + longest_param_name_len = max(len(param_name) for param_name, _ in self.config) | ||
65 | + for param_name, param_val in self.config: | ||
66 | + self.log('{name: <{name_len}}{val}'.format( | ||
67 | + name=param_name, val=param_val, name_len=longest_param_name_len+2)) | ||
68 | + self.log('---------------------------------------------------------------------') | ||
69 | + | ||
70 | + @property | ||
71 | + def logger(self): | ||
72 | + return self.config.get_logger() | ||
73 | + | ||
74 | + def log(self, msg): | ||
75 | + self.logger.info(msg) | ||
76 | + | ||
77 | + def _init_num_of_examples(self): | ||
78 | + self.log('Checking number of examples ...') | ||
79 | + if self.config.is_training: | ||
80 | + self.config.NUM_TRAIN_EXAMPLES = self._get_num_of_examples_for_dataset(self.config.train_data_path) | ||
81 | + self.log(' Number of train examples: {}'.format(self.config.NUM_TRAIN_EXAMPLES)) | ||
82 | + if self.config.is_testing: | ||
83 | + self.config.NUM_TEST_EXAMPLES = self._get_num_of_examples_for_dataset(self.config.TEST_DATA_PATH) | ||
84 | + self.log(' Number of test examples: {}'.format(self.config.NUM_TEST_EXAMPLES)) | ||
85 | + | ||
86 | + @staticmethod | ||
87 | + def _get_num_of_examples_for_dataset(dataset_path: str) -> int: | ||
88 | + dataset_num_examples_file_path = dataset_path + '.num_examples' | ||
89 | + if os.path.isfile(dataset_num_examples_file_path): | ||
90 | + with open(dataset_num_examples_file_path, 'r') as file: | ||
91 | + num_examples_in_dataset = int(file.readline()) | ||
92 | + else: | ||
93 | + num_examples_in_dataset = common.count_lines_in_file(dataset_path) | ||
94 | + with open(dataset_num_examples_file_path, 'w') as file: | ||
95 | + file.write(str(num_examples_in_dataset)) | ||
96 | + return num_examples_in_dataset | ||
97 | + | ||
98 | + def load_or_build(self): | ||
99 | + self.vocabs = Code2VecVocabs(self.config) | ||
100 | + self._load_or_create_inner_model() | ||
101 | + | ||
102 | + def save(self, model_save_path=None): | ||
103 | + if model_save_path is None: | ||
104 | + model_save_path = self.config.MODEL_SAVE_PATH | ||
105 | + model_save_dir = '/'.join(model_save_path.split('/')[:-1]) | ||
106 | + if not os.path.isdir(model_save_dir): | ||
107 | + os.makedirs(model_save_dir, exist_ok=True) | ||
108 | + self.vocabs.save(self.config.get_vocabularies_path_from_model_path(model_save_path)) | ||
109 | + self._save_inner_model(model_save_path) | ||
110 | + | ||
111 | + def _write_code_vectors(self, file, code_vectors): | ||
112 | + for vec in code_vectors: | ||
113 | + file.write(' '.join(map(str, vec)) + '\n') | ||
114 | + | ||
115 | + def _get_attention_weight_per_context( | ||
116 | + self, path_source_strings: Iterable[str], path_strings: Iterable[str], path_target_strings: Iterable[str], | ||
117 | + attention_weights: Iterable[float]) -> Dict[Tuple[str, str, str], float]: | ||
118 | + attention_weights = np.squeeze(attention_weights, axis=-1) # (max_contexts, ) | ||
119 | + attention_per_context: Dict[Tuple[str, str, str], float] = {} | ||
120 | + | ||
121 | + for path_source, path, path_target, weight in \ | ||
122 | + zip(path_source_strings, path_strings, path_target_strings, attention_weights): | ||
123 | + string_context_triplet = (common.binary_to_string(path_source), | ||
124 | + common.binary_to_string(path), | ||
125 | + common.binary_to_string(path_target)) | ||
126 | + attention_per_context[string_context_triplet] = weight | ||
127 | + return attention_per_context | ||
128 | + | ||
129 | + def close_session(self): | ||
130 | + pass | ||
131 | + | ||
132 | + @abc.abstractmethod | ||
133 | + def train(self): | ||
134 | + ... | ||
135 | + | ||
136 | + @abc.abstractmethod | ||
137 | + def evaluate(self) -> Optional[ModelEvaluationResults]: | ||
138 | + ... | ||
139 | + | ||
140 | + @abc.abstractmethod | ||
141 | + def predict(self, predict_data_lines: Iterable[str]) -> List[ModelPredictionResults]: | ||
142 | + ... | ||
143 | + | ||
144 | + @abc.abstractmethod | ||
145 | + def _save_inner_model(self, path): | ||
146 | + ... | ||
147 | + | ||
148 | + def _load_or_create_inner_model(self): | ||
149 | + if self.config.is_loading: | ||
150 | + self._load_inner_model() | ||
151 | + else: | ||
152 | + self._create_inner_model() | ||
153 | + | ||
154 | + @abc.abstractmethod | ||
155 | + def _load_inner_model(self): | ||
156 | + ... | ||
157 | + | ||
158 | + def _create_inner_model(self): | ||
159 | + pass | ||
160 | + | ||
161 | + def _initialize(self): | ||
162 | + pass | ||
163 | + | ||
164 | + @abc.abstractmethod | ||
165 | + def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray: | ||
166 | + ... | ||
167 | + | ||
168 | + def save_word2vec_format(self, dest_save_path: str, vocab_type: VocabType): | ||
169 | + if vocab_type not in VocabType: | ||
170 | + raise ValueError('`vocab_type` should be `VocabType.Token`, `VocabType.Target` or `VocabType.Path`.') | ||
171 | + vocab_embedding_matrix = self._get_vocab_embedding_as_np_array(vocab_type) | ||
172 | + index_to_word = self.vocabs.get(vocab_type).index_to_word | ||
173 | + with open(dest_save_path, 'w') as words_file: | ||
174 | + common.save_word2vec_file(words_file, index_to_word, vocab_embedding_matrix) |
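Note: Code2VecModelBase defines the contract a backend has to fill in; a skeletal subclass makes that surface explicit (everything below is a stub for illustration, not a working model):

import numpy as np
from typing import Iterable, List, Optional

from model_base import Code2VecModelBase, ModelEvaluationResults, ModelPredictionResults
from vocabularies import VocabType


class DummyCode2VecModel(Code2VecModelBase):
    # The base class handles config verification, vocabularies and logging;
    # a concrete backend only supplies the pieces below.
    def train(self):
        self.log('train() called - no-op in this sketch')

    def evaluate(self) -> Optional[ModelEvaluationResults]:
        return ModelEvaluationResults(topk_acc=0.0, subtoken_precision=0.0,
                                      subtoken_recall=0.0, subtoken_f1=0.0)

    def predict(self, predict_data_lines: Iterable[str]) -> List[ModelPredictionResults]:
        return []

    def _save_inner_model(self, path):
        pass

    def _load_inner_model(self):
        pass

    def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
        return np.zeros((1, 1))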
code/code2vec/path_context_reader.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +from typing import Dict, Tuple, NamedTuple, Union, Optional, Iterable | ||
3 | +from config import Config | ||
4 | +from vocabularies import Code2VecVocabs | ||
5 | +import abc | ||
6 | +from functools import reduce | ||
7 | +from enum import Enum | ||
8 | + | ||
9 | + | ||
10 | +class EstimatorAction(Enum): | ||
11 | + Train = 'train' | ||
12 | + Evaluate = 'evaluate' | ||
13 | + Predict = 'predict' | ||
14 | + | ||
15 | + @property | ||
16 | + def is_train(self): | ||
17 | + return self is EstimatorAction.Train | ||
18 | + | ||
19 | + @property | ||
20 | + def is_evaluate(self): | ||
21 | + return self is EstimatorAction.Evaluate | ||
22 | + | ||
23 | + @property | ||
24 | + def is_predict(self): | ||
25 | + return self is EstimatorAction.Predict | ||
26 | + | ||
27 | + @property | ||
28 | + def is_evaluate_or_predict(self): | ||
29 | + return self.is_evaluate or self.is_predict | ||
30 | + | ||
31 | + | ||
32 | +class ReaderInputTensors(NamedTuple): | ||
33 | + path_source_token_indices: tf.Tensor | ||
34 | + path_indices: tf.Tensor | ||
35 | + path_target_token_indices: tf.Tensor | ||
36 | + context_valid_mask: tf.Tensor | ||
37 | + target_index: Optional[tf.Tensor] = None | ||
38 | + target_string: Optional[tf.Tensor] = None | ||
39 | + path_source_token_strings: Optional[tf.Tensor] = None | ||
40 | + path_strings: Optional[tf.Tensor] = None | ||
41 | + path_target_token_strings: Optional[tf.Tensor] = None | ||
42 | + | ||
43 | + | ||
44 | +class ModelInputTensorsFormer(abc.ABC): | ||
45 | + @abc.abstractmethod | ||
46 | + def to_model_input_form(self, input_tensors: ReaderInputTensors): | ||
47 | + ... | ||
48 | + | ||
49 | + @abc.abstractmethod | ||
50 | + def from_model_input_form(self, input_row) -> ReaderInputTensors: | ||
51 | + ... | ||
52 | + | ||
53 | + | ||
54 | +class PathContextReader: | ||
55 | + def __init__(self, | ||
56 | + vocabs: Code2VecVocabs, | ||
57 | + config: Config, | ||
58 | + model_input_tensors_former: ModelInputTensorsFormer, | ||
59 | + estimator_action: EstimatorAction, | ||
60 | + repeat_endlessly: bool = False): | ||
61 | + self.vocabs = vocabs | ||
62 | + self.config = config | ||
63 | + self.model_input_tensors_former = model_input_tensors_former | ||
64 | + self.estimator_action = estimator_action | ||
65 | + self.repeat_endlessly = repeat_endlessly | ||
66 | + self.CONTEXT_PADDING = ','.join([self.vocabs.token_vocab.special_words.PAD, | ||
67 | + self.vocabs.path_vocab.special_words.PAD, | ||
68 | + self.vocabs.token_vocab.special_words.PAD]) | ||
69 | + self.csv_record_defaults = [[self.vocabs.target_vocab.special_words.OOV]] + \ | ||
70 | + ([[self.CONTEXT_PADDING]] * self.config.MAX_CONTEXTS) | ||
71 | + | ||
72 | + self.create_needed_vocabs_lookup_tables(self.vocabs) | ||
73 | + | ||
74 | + self._dataset: Optional[tf.data.Dataset] = None | ||
75 | + | ||
76 | + @classmethod | ||
77 | + def create_needed_vocabs_lookup_tables(cls, vocabs: Code2VecVocabs): | ||
78 | + vocabs.token_vocab.get_word_to_index_lookup_table() | ||
79 | + vocabs.path_vocab.get_word_to_index_lookup_table() | ||
80 | + vocabs.target_vocab.get_word_to_index_lookup_table() | ||
81 | + | ||
82 | + @tf.function | ||
83 | + def process_input_row(self, row_placeholder): | ||
84 | + parts = tf.io.decode_csv( | ||
85 | + row_placeholder, record_defaults=self.csv_record_defaults, field_delim=' ', use_quote_delim=False) | ||
86 | + tensors = self._map_raw_dataset_row_to_input_tensors(*parts) | ||
87 | + | ||
88 | + tensors_expanded = ReaderInputTensors( | ||
89 | + **{name: None if tensor is None else tf.expand_dims(tensor, axis=0) | ||
90 | + for name, tensor in tensors._asdict().items()}) | ||
91 | + return self.model_input_tensors_former.to_model_input_form(tensors_expanded) | ||
92 | + | ||
93 | + def process_and_iterate_input_from_data_lines(self, input_data_lines: Iterable) -> Iterable: | ||
94 | + for data_row in input_data_lines: | ||
95 | + processed_row = self.process_input_row(data_row) | ||
96 | + yield processed_row | ||
97 | + | ||
98 | + def get_dataset(self, input_data_rows: Optional = None) -> tf.data.Dataset: | ||
99 | + if self._dataset is None: | ||
100 | + self._dataset = self._create_dataset_pipeline(input_data_rows) | ||
101 | + return self._dataset | ||
102 | + | ||
103 | + def _create_dataset_pipeline(self, input_data_rows: Optional = None) -> tf.data.Dataset: | ||
104 | + if input_data_rows is None: | ||
105 | + assert not self.estimator_action.is_predict | ||
106 | + dataset = tf.data.experimental.CsvDataset( | ||
107 | + self.config.data_path(is_evaluating=self.estimator_action.is_evaluate), | ||
108 | + record_defaults=self.csv_record_defaults, field_delim=' ', use_quote_delim=False, | ||
109 | + buffer_size=self.config.CSV_BUFFER_SIZE) | ||
110 | + else: | ||
111 | + dataset = tf.data.Dataset.from_tensor_slices(input_data_rows) | ||
112 | + dataset = dataset.map( | ||
113 | + lambda input_line: tf.io.decode_csv( | ||
114 | + tf.reshape(tf.cast(input_line, tf.string), ()), | ||
115 | + record_defaults=self.csv_record_defaults, | ||
116 | + field_delim=' ', use_quote_delim=False)) | ||
117 | + | ||
118 | + if self.repeat_endlessly: | ||
119 | + dataset = dataset.repeat() | ||
120 | + if self.estimator_action.is_train: | ||
121 | + if not self.repeat_endlessly and self.config.NUM_TRAIN_EPOCHS > 1: | ||
122 | + dataset = dataset.repeat(self.config.NUM_TRAIN_EPOCHS) | ||
123 | + dataset = dataset.shuffle(self.config.SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True) | ||
124 | + | ||
125 | + dataset = dataset.map(self._map_raw_dataset_row_to_expected_model_input_form, | ||
126 | + num_parallel_calls=self.config.READER_NUM_PARALLEL_BATCHES) | ||
127 | + batch_size = self.config.batch_size(is_evaluating=self.estimator_action.is_evaluate) | ||
128 | + if self.estimator_action.is_predict: | ||
129 | + dataset = dataset.batch(1) | ||
130 | + else: | ||
131 | + dataset = dataset.filter(self._filter_input_rows) | ||
132 | + dataset = dataset.batch(batch_size) | ||
133 | + | ||
134 | + dataset = dataset.prefetch(buffer_size=40) | ||
135 | + return dataset | ||
136 | + | ||
137 | + def _filter_input_rows(self, *row_parts) -> tf.bool: | ||
138 | + row_parts = self.model_input_tensors_former.from_model_input_form(row_parts) | ||
139 | + | ||
140 | + any_word_valid_mask_per_context_part = [ | ||
141 | + tf.not_equal(tf.reduce_max(row_parts.path_source_token_indices, axis=0), | ||
142 | + self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]), | ||
143 | + tf.not_equal(tf.reduce_max(row_parts.path_target_token_indices, axis=0), | ||
144 | + self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]), | ||
145 | + tf.not_equal(tf.reduce_max(row_parts.path_indices, axis=0), | ||
146 | + self.vocabs.path_vocab.word_to_index[self.vocabs.path_vocab.special_words.PAD])] | ||
147 | + any_contexts_is_valid = reduce(tf.logical_or, any_word_valid_mask_per_context_part) | ||
148 | + | ||
149 | + if self.estimator_action.is_evaluate: | ||
150 | + cond = any_contexts_is_valid | ||
151 | + else: | ||
152 | + word_is_valid = tf.greater( | ||
153 | + row_parts.target_index, self.vocabs.target_vocab.word_to_index[self.vocabs.target_vocab.special_words.OOV]) # scalar | ||
154 | + cond = tf.logical_and(word_is_valid, any_contexts_is_valid) | ||
155 | + | ||
156 | + return cond | ||
157 | + | ||
158 | + def _map_raw_dataset_row_to_expected_model_input_form(self, *row_parts) -> \ | ||
159 | + Tuple[Union[tf.Tensor, Tuple[tf.Tensor, ...], Dict[str, tf.Tensor]], ...]: | ||
160 | + tensors = self._map_raw_dataset_row_to_input_tensors(*row_parts) | ||
161 | + return self.model_input_tensors_former.to_model_input_form(tensors) | ||
162 | + | ||
163 | + def _map_raw_dataset_row_to_input_tensors(self, *row_parts) -> ReaderInputTensors: | ||
164 | + row_parts = list(row_parts) | ||
165 | + target_str = row_parts[0] | ||
166 | + target_index = self.vocabs.target_vocab.lookup_index(target_str) | ||
167 | + | ||
168 | + contexts_str = tf.stack(row_parts[1:(self.config.MAX_CONTEXTS + 1)], axis=0) | ||
169 | + split_contexts = tf.compat.v1.string_split(contexts_str, sep=',', skip_empty=False) | ||
170 | + sparse_split_contexts = tf.sparse.SparseTensor( | ||
171 | + indices=split_contexts.indices, values=split_contexts.values, dense_shape=[self.config.MAX_CONTEXTS, 3]) | ||
172 | + dense_split_contexts = tf.reshape( | ||
173 | + tf.sparse.to_dense(sp_input=sparse_split_contexts, default_value=self.vocabs.token_vocab.special_words.PAD), | ||
174 | + shape=[self.config.MAX_CONTEXTS, 3]) | ||
175 | + | ||
176 | + path_source_token_strings = tf.squeeze( | ||
177 | + tf.slice(dense_split_contexts, begin=[0, 0], size=[self.config.MAX_CONTEXTS, 1]), axis=1) | ||
178 | + path_strings = tf.squeeze( | ||
179 | + tf.slice(dense_split_contexts, begin=[0, 1], size=[self.config.MAX_CONTEXTS, 1]), axis=1) | ||
180 | + path_target_token_strings = tf.squeeze( | ||
181 | + tf.slice(dense_split_contexts, begin=[0, 2], size=[self.config.MAX_CONTEXTS, 1]), axis=1) | ||
182 | + | ||
183 | + path_source_token_indices = self.vocabs.token_vocab.lookup_index(path_source_token_strings) | ||
184 | + path_indices = self.vocabs.path_vocab.lookup_index(path_strings) | ||
185 | + path_target_token_indices = self.vocabs.token_vocab.lookup_index(path_target_token_strings) | ||
186 | + | ||
187 | + valid_word_mask_per_context_part = [ | ||
188 | + tf.not_equal(path_source_token_indices, self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]), | ||
189 | + tf.not_equal(path_target_token_indices, self.vocabs.token_vocab.word_to_index[self.vocabs.token_vocab.special_words.PAD]), | ||
190 | + tf.not_equal(path_indices, self.vocabs.path_vocab.word_to_index[self.vocabs.path_vocab.special_words.PAD])] | ||
191 | + context_valid_mask = tf.cast(reduce(tf.logical_or, valid_word_mask_per_context_part), dtype=tf.float32) | ||
192 | + | ||
193 | + return ReaderInputTensors( | ||
194 | + path_source_token_indices=path_source_token_indices, | ||
195 | + path_indices=path_indices, | ||
196 | + path_target_token_indices=path_target_token_indices, | ||
197 | + context_valid_mask=context_valid_mask, | ||
198 | + target_index=target_index, | ||
199 | + target_string=target_str, | ||
200 | + path_source_token_strings=path_source_token_strings, | ||
201 | + path_strings=path_strings, | ||
202 | + path_target_token_strings=path_target_token_strings | ||
203 | + ) |
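Note: PathContextReader hands the parsed tensors to a ModelInputTensorsFormer supplied by the model (the TensorFlow model below references a _TFTrainModelInputTensorsFormer that is not shown in this diff). A sketch of one possible former, only to illustrate the to/from round-trip contract:

from path_context_reader import ModelInputTensorsFormer, ReaderInputTensors


class TupleInputTensorsFormer(ModelInputTensorsFormer):
    # Packs the reader tensors into a flat tuple; from_model_input_form must
    # unpack them in exactly the same order.
    def to_model_input_form(self, input_tensors: ReaderInputTensors):
        return (input_tensors.target_index,
                input_tensors.path_source_token_indices,
                input_tensors.path_indices,
                input_tensors.path_target_token_indices,
                input_tensors.context_valid_mask)

    def from_model_input_form(self, input_row) -> ReaderInputTensors:
        target_index, source_indices, path_indices, target_token_indices, mask = input_row
        return ReaderInputTensors(
            path_source_token_indices=source_indices,
            path_indices=path_indices,
            path_target_token_indices=target_token_indices,
            context_valid_mask=mask,
            target_index=target_index)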
code/code2vec/preprocess.py
0 → 100644
1 | +import random | ||
2 | +from argparse import ArgumentParser | ||
3 | +import common | ||
4 | +import pickle | ||
5 | + | ||
6 | +def save_dictionaries(dataset_name, word_to_count, path_to_count, target_to_count, | ||
7 | + num_training_examples): | ||
8 | + save_dict_file_path = '{}.dict.c2v'.format(dataset_name) | ||
9 | + with open(save_dict_file_path, 'wb') as file: | ||
10 | + pickle.dump(word_to_count, file) | ||
11 | + pickle.dump(path_to_count, file) | ||
12 | + pickle.dump(target_to_count, file) | ||
13 | + pickle.dump(num_training_examples, file) | ||
14 | + print('Dictionaries saved to: {}'.format(save_dict_file_path)) | ||
15 | + | ||
16 | + | ||
17 | +def process_file(file_path, data_file_role, dataset_name, word_to_count, path_to_count, max_contexts): | ||
18 | + sum_total = 0 | ||
19 | + sum_sampled = 0 | ||
20 | + total = 0 | ||
21 | + empty = 0 | ||
22 | + max_unfiltered = 0 | ||
23 | + output_path = '{}.{}.c2v'.format(dataset_name, data_file_role) | ||
24 | + with open(output_path, 'w') as outfile: | ||
25 | + with open(file_path, 'r') as file: | ||
26 | + for line in file: | ||
27 | + parts = line.rstrip('\n').split(' ') | ||
28 | + target_name = parts[0] | ||
29 | + contexts = parts[1:] | ||
30 | + | ||
31 | + if len(contexts) > max_unfiltered: | ||
32 | + max_unfiltered = len(contexts) | ||
33 | + sum_total += len(contexts) | ||
34 | + | ||
35 | + if len(contexts) > max_contexts: | ||
36 | + context_parts = [c.split(',') for c in contexts] | ||
37 | + full_found_contexts = [c for i, c in enumerate(contexts) | ||
38 | + if context_full_found(context_parts[i], word_to_count, path_to_count)] | ||
39 | + partial_found_contexts = [c for i, c in enumerate(contexts) | ||
40 | + if context_partial_found(context_parts[i], word_to_count, path_to_count) | ||
41 | + and not context_full_found(context_parts[i], word_to_count, | ||
42 | + path_to_count)] | ||
43 | + if len(full_found_contexts) > max_contexts: | ||
44 | + contexts = random.sample(full_found_contexts, max_contexts) | ||
45 | + elif len(full_found_contexts) <= max_contexts \ | ||
46 | + and len(full_found_contexts) + len(partial_found_contexts) > max_contexts: | ||
47 | + contexts = full_found_contexts + \ | ||
48 | + random.sample(partial_found_contexts, max_contexts - len(full_found_contexts)) | ||
49 | + else: | ||
50 | + contexts = full_found_contexts + partial_found_contexts | ||
51 | + | ||
52 | + if len(contexts) == 0: | ||
53 | + empty += 1 | ||
54 | + continue | ||
55 | + | ||
56 | + sum_sampled += len(contexts) | ||
57 | + | ||
58 | + csv_padding = " " * (max_contexts - len(contexts)) | ||
59 | + outfile.write(target_name + ' ' + " ".join(contexts) + csv_padding + '\n') | ||
60 | + total += 1 | ||
61 | + | ||
62 | + print('File: ' + file_path) | ||
63 | + print('Average total contexts: ' + str(float(sum_total) / total)) | ||
64 | + print('Average final (after sampling) contexts: ' + str(float(sum_sampled) / total)) | ||
65 | + print('Total examples: ' + str(total)) | ||
66 | + print('Empty examples: ' + str(empty)) | ||
67 | + print('Max number of contexts per word: ' + str(max_unfiltered)) | ||
68 | + return total | ||
69 | + | ||
70 | + | ||
71 | +def context_full_found(context_parts, word_to_count, path_to_count): | ||
72 | + return context_parts[0] in word_to_count \ | ||
73 | + and context_parts[1] in path_to_count and context_parts[2] in word_to_count | ||
74 | + | ||
75 | + | ||
76 | +def context_partial_found(context_parts, word_to_count, path_to_count): | ||
77 | + return context_parts[0] in word_to_count \ | ||
78 | + or context_parts[1] in path_to_count or context_parts[2] in word_to_count | ||
79 | + | ||
80 | + | ||
81 | +if __name__ == '__main__': | ||
82 | + | ||
83 | + parser = ArgumentParser() | ||
84 | + parser.add_argument("-trd", "--train_data", dest="train_data_path", | ||
85 | + help="path to training data file", required=True) | ||
86 | + parser.add_argument("-ted", "--test_data", dest="test_data_path", | ||
87 | + help="path to test data file", required=True) | ||
88 | + parser.add_argument("-vd", "--val_data", dest="val_data_path", | ||
89 | + help="path to validation data file", required=True) | ||
90 | + parser.add_argument("-mc", "--max_contexts", dest="max_contexts", default=200, | ||
91 | + help="number of max contexts to keep", required=False) | ||
92 | + parser.add_argument("-wvs", "--word_vocab_size", dest="word_vocab_size", default=1301136, | ||
93 | + help="Max number of origin words to keep in the vocabulary", required=False) | ||
94 | + parser.add_argument("-pvs", "--path_vocab_size", dest="path_vocab_size", default=911417, | ||
95 | + help="Max number of paths to keep in the vocabulary", required=False) | ||
96 | + parser.add_argument("-tvs", "--target_vocab_size", dest="target_vocab_size", default=261245, | ||
97 | + help="Max number of target words to keep in the vocabulary", required=False) | ||
98 | + parser.add_argument("-wh", "--word_histogram", dest="word_histogram", | ||
99 | + help="word histogram file", metavar="FILE", required=True) | ||
100 | + parser.add_argument("-ph", "--path_histogram", dest="path_histogram", | ||
101 | + help="path_histogram file", metavar="FILE", required=True) | ||
102 | + parser.add_argument("-th", "--target_histogram", dest="target_histogram", | ||
103 | + help="target histogram file", metavar="FILE", required=True) | ||
104 | + parser.add_argument("-o", "--output_name", dest="output_name", | ||
105 | + help="output name - the base name for the created dataset", metavar="FILE", required=True, | ||
106 | + default='data') | ||
107 | + args = parser.parse_args() | ||
108 | + | ||
109 | + train_data_path = args.train_data_path | ||
110 | + test_data_path = args.test_data_path | ||
111 | + val_data_path = args.val_data_path | ||
112 | + word_histogram_path = args.word_histogram | ||
113 | + path_histogram_path = args.path_histogram | ||
114 | + | ||
115 | + word_histogram_data = common.common.load_vocab_from_histogram(word_histogram_path, start_from=1, | ||
116 | + max_size=int(args.word_vocab_size), | ||
117 | + return_counts=True) | ||
118 | + _, _, _, word_to_count = word_histogram_data | ||
119 | + _, _, _, path_to_count = common.common.load_vocab_from_histogram(path_histogram_path, start_from=1, | ||
120 | + max_size=int(args.path_vocab_size), | ||
121 | + return_counts=True) | ||
122 | + _, _, _, target_to_count = common.common.load_vocab_from_histogram(args.target_histogram, start_from=1, | ||
123 | + max_size=int(args.target_vocab_size), | ||
124 | + return_counts=True) | ||
125 | + | ||
126 | + num_training_examples = 0 | ||
127 | + for data_file_path, data_role in zip([test_data_path, val_data_path, train_data_path], ['test', 'val', 'train']): | ||
128 | + num_examples = process_file(file_path=data_file_path, data_file_role=data_role, dataset_name=args.output_name, | ||
129 | + word_to_count=word_to_count, path_to_count=path_to_count, | ||
130 | + max_contexts=int(args.max_contexts)) | ||
131 | + if data_role == 'train': | ||
132 | + num_training_examples = num_examples | ||
133 | + | ||
134 | + save_dictionaries(dataset_name=args.output_name, word_to_count=word_to_count, | ||
135 | + path_to_count=path_to_count, target_to_count=target_to_count, | ||
136 | + num_training_examples=num_training_examples) |
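Note: the full/partial checks above only test membership in the training-set count dictionaries. A toy illustration (dictionaries and contexts are made up):

from preprocess import context_full_found, context_partial_found

word_to_count = {'x': 10, 'y': 3}
path_to_count = {'P1': 7}

# All three parts known -> full; only one part known -> partial but not full.
assert context_full_found(['x', 'P1', 'y'], word_to_count, path_to_count)
assert context_partial_found(['x', 'P2', 'z'], word_to_count, path_to_count)
assert not context_full_found(['x', 'P2', 'z'], word_to_count, path_to_count)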
code/code2vec/preprocess_py.sh
0 → 100644
1 | +TRAIN_DIR=dataset_train | ||
2 | +VAL_DIR=dataset_val | ||
3 | +TEST_DIR=dataset_test | ||
4 | +DATASET_NAME=dataset | ||
5 | +MAX_CONTEXTS=200 | ||
6 | +WORD_VOCAB_SIZE=1301136 | ||
7 | +PATH_VOCAB_SIZE=911417 | ||
8 | +TARGET_VOCAB_SIZE=261245 | ||
9 | +NUM_THREADS=64 | ||
10 | +PYTHON=python | ||
11 | +########################################################### | ||
12 | + | ||
13 | +TRAIN_DATA_PATH=data/path_contexts_train.csv | ||
14 | +VAL_DATA_PATH=data/path_contexts_val.csv | ||
15 | +TEST_DATA_PATH=data/path_contexts_test.csv | ||
16 | + | ||
17 | +TRAIN_DATA_FILE=${TRAIN_DATA_PATH} | ||
18 | +VAL_DATA_FILE=${VAL_DATA_PATH} | ||
19 | +TEST_DATA_FILE=${TEST_DATA_PATH} | ||
20 | + | ||
21 | +mkdir -p data | ||
22 | +mkdir -p data/${DATASET_NAME} | ||
23 | + | ||
24 | +TARGET_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.tgt.c2v | ||
25 | +ORIGIN_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.ori.c2v | ||
26 | +PATH_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.path.c2v | ||
27 | + | ||
28 | +cat ${TRAIN_DATA_FILE} | cut -d' ' -f1 | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${TARGET_HISTOGRAM_FILE} | ||
29 | +cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${ORIGIN_HISTOGRAM_FILE} | ||
30 | +cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${PATH_HISTOGRAM_FILE} | ||
31 | + | ||
32 | +DIR=`dirname "$0"` | ||
33 | + | ||
34 | +${PYTHON} ${DIR}/preprocess.py --train_data ${TRAIN_DATA_FILE} --test_data ${TEST_DATA_FILE} --val_data ${VAL_DATA_FILE} \ | ||
35 | + --max_contexts ${MAX_CONTEXTS} --word_vocab_size ${WORD_VOCAB_SIZE} --path_vocab_size ${PATH_VOCAB_SIZE} \ | ||
36 | + --target_vocab_size ${TARGET_VOCAB_SIZE} --word_histogram ${ORIGIN_HISTOGRAM_FILE} \ | ||
37 | + --path_histogram ${PATH_HISTOGRAM_FILE} --target_histogram ${TARGET_HISTOGRAM_FILE} --output_name data/${DATASET_NAME}/${DATASET_NAME} | ||
38 | + | ||
39 | +rm ${TARGET_HISTOGRAM_FILE} ${ORIGIN_HISTOGRAM_FILE} ${PATH_HISTOGRAM_FILE} |
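Note: the three cut/awk pipelines above just build "<value> <count>" histograms for the targets, tokens and paths of the training file. The same counts can be reproduced in Python if the coreutils are unavailable (a sketch; the paths are the ones this script defines):

from collections import Counter

target_histo, origin_histo, path_histo = Counter(), Counter(), Counter()
with open('data/path_contexts_train.csv') as f:
    for line in f:
        parts = line.rstrip('\n').split(' ')
        target_histo[parts[0]] += 1                     # first field: target name
        for context in parts[1:]:
            if not context:
                continue
            token1, path, token2 = context.split(',')   # "token,path,token" triple
            origin_histo[token1] += 1
            origin_histo[token2] += 1
            path_histo[path] += 1

# e.g. write the target histogram where the script expects it:
with open('data/dataset/dataset.histo.tgt.c2v', 'w') as out:
    for value, count in target_histo.items():
        out.write('{} {}\n'.format(value, count))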
code/code2vec/py_extractor.py
0 → 100644
1 | +import subprocess | ||
2 | + | ||
3 | +class PyExtractor: | ||
4 | + def __init__(self, config): | ||
5 | + self.config = config | ||
6 | + | ||
7 | + def read_file(self, input_filename): | ||
8 | + with open(input_filename, 'r') as file: | ||
9 | + return file.readlines() | ||
10 | + | ||
11 | + def extract_paths(self, path): | ||
12 | + output = self.read_file(path) | ||
13 | + | ||
14 | + if len(output) == 0: | ||
15 | + # nothing to extract - the input file is empty | ||
16 | + raise ValueError('no path contexts found in input file: %s' % path) | ||
17 | + hash_to_string_dict = {} | ||
18 | + result = [] | ||
19 | + for i, line in enumerate(output): | ||
20 | + parts = line.rstrip().split(' ') | ||
21 | + method_name = parts[0] | ||
22 | + current_result_line_parts = [method_name] | ||
23 | + contexts = parts[1:] | ||
24 | + for context in contexts[:self.config.MAX_CONTEXTS]: | ||
25 | + context_parts = context.split(',') | ||
26 | + context_word1 = context_parts[0] | ||
27 | + context_path = context_parts[1] | ||
28 | + context_word2 = context_parts[2] | ||
29 | + hashed_path = str(context_path) | ||
30 | + hash_to_string_dict[hashed_path] = context_path | ||
31 | + current_result_line_parts += ['%s,%s,%s' % (context_word1, hashed_path, context_word2)] | ||
32 | + space_padding = ' ' * (self.config.MAX_CONTEXTS - len(contexts)) | ||
33 | + result_line = ' '.join(current_result_line_parts) + space_padding | ||
34 | + result.append(result_line) | ||
35 | + return result, hash_to_string_dict |
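Note: PyExtractor does not invoke an external extractor: it re-reads an already-extracted file, truncates each line to MAX_CONTEXTS contexts and pads with spaces. A toy round trip (contents invented; SimpleNamespace stands in for the full Config since only MAX_CONTEXTS is used here):

from types import SimpleNamespace
from py_extractor import PyExtractor

with open('toy.c2v', 'w') as f:
    f.write('my|method a,P_1,b b,P_2,c\n')

extractor = PyExtractor(SimpleNamespace(MAX_CONTEXTS=4))
lines, hash_to_path = extractor.extract_paths('toy.c2v')
print(lines[0])       # 'my|method a,P_1,b b,P_2,c' plus two padding spaces
print(hash_to_path)   # {'P_1': 'P_1', 'P_2': 'P_2'} - paths are passed through unhashed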
code/code2vec/tensorflow_model.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy as np | ||
3 | +import time | ||
4 | +from typing import Dict, Optional, List, Iterable | ||
5 | +from collections import Counter | ||
6 | +from functools import partial | ||
7 | + | ||
8 | +from path_context_reader import PathContextReader, ModelInputTensorsFormer, ReaderInputTensors, EstimatorAction | ||
9 | +from common import common | ||
10 | +from vocabularies import VocabType | ||
11 | +from config import Config | ||
12 | +from model_base import Code2VecModelBase, ModelEvaluationResults, ModelPredictionResults | ||
13 | + | ||
14 | + | ||
15 | +tf.compat.v1.disable_eager_execution() | ||
16 | + | ||
17 | + | ||
18 | +class Code2VecModel(Code2VecModelBase): | ||
19 | + def __init__(self, config: Config): | ||
20 | + self.sess = tf.compat.v1.Session() | ||
21 | + self.saver = None | ||
22 | + | ||
23 | + self.eval_reader = None | ||
24 | + self.eval_input_iterator_reset_op = None | ||
25 | + self.predict_reader = None | ||
26 | + self.MAX_BATCH_NUM = 30 | ||
27 | + | ||
28 | + self.predict_placeholder = None | ||
29 | + self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, self.eval_code_vectors = None, None, None, None | ||
30 | + self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op = None, None, None | ||
31 | + | ||
32 | + self.vocab_type_to_tf_variable_name_mapping: Dict[VocabType, str] = { | ||
33 | + VocabType.Token: 'WORDS_VOCAB', | ||
34 | + VocabType.Target: 'TARGET_WORDS_VOCAB', | ||
35 | + VocabType.Path: 'PATHS_VOCAB' | ||
36 | + } | ||
37 | + | ||
38 | + super(Code2VecModel, self).__init__(config) | ||
39 | + | ||
40 | + def train(self): | ||
41 | + self.log('Starting training') | ||
42 | + start_time = time.time() | ||
43 | + | ||
44 | + batch_num = 0 | ||
45 | + sum_loss = 0 | ||
46 | + multi_batch_start_time = time.time() | ||
47 | + num_batches_to_save_and_eval = max(int(self.config.train_steps_per_epoch * self.config.SAVE_EVERY_EPOCHS), 1) | ||
48 | + | ||
49 | + train_reader = PathContextReader(vocabs=self.vocabs, | ||
50 | + model_input_tensors_former=_TFTrainModelInputTensorsFormer(), | ||
51 | + config=self.config, estimator_action=EstimatorAction.Train) | ||
52 | + input_iterator = tf.compat.v1.data.make_initializable_iterator(train_reader.get_dataset()) | ||
53 | + input_iterator_reset_op = input_iterator.initializer | ||
54 | + input_tensors = input_iterator.get_next() | ||
55 | + | ||
56 | + optimizer, train_loss = self._build_tf_training_graph(input_tensors) | ||
57 | + self.saver = tf.compat.v1.train.Saver(max_to_keep=self.config.MAX_TO_KEEP) | ||
58 | + | ||
59 | + self.log('Number of trainable params: {}'.format( | ||
60 | + np.sum([np.prod(v.get_shape().as_list()) for v in tf.compat.v1.trainable_variables()]))) | ||
61 | + for variable in tf.compat.v1.trainable_variables(): | ||
62 | + self.log("variable name: {} -- shape: {} -- #params: {}".format( | ||
63 | + variable.name, variable.get_shape(), np.prod(variable.get_shape().as_list()))) | ||
64 | + | ||
65 | + self._initialize_session_variables() | ||
66 | + | ||
67 | + if self.config.MODEL_LOAD_PATH: | ||
68 | + self._load_inner_model(self.sess) | ||
69 | + | ||
70 | + self.sess.run(input_iterator_reset_op) | ||
71 | + time.sleep(1) | ||
72 | + self.log('Started reader...') | ||
73 | + | ||
74 | + try: | ||
75 | + while batch_num <= self.MAX_BATCH_NUM: | ||
76 | + batch_num += 1 | ||
77 | + | ||
78 | + _, batch_loss = self.sess.run([optimizer, train_loss]) | ||
79 | + | ||
80 | + sum_loss += batch_loss | ||
81 | + if batch_num % self.config.NUM_BATCHES_TO_LOG_PROGRESS == 0: | ||
82 | + self._trace_training(sum_loss, batch_num, multi_batch_start_time) | ||
83 | + sum_loss = 0 | ||
84 | + multi_batch_start_time = time.time() | ||
85 | + if batch_num % num_batches_to_save_and_eval == 0: | ||
86 | + epoch_num = int((batch_num / num_batches_to_save_and_eval) * self.config.SAVE_EVERY_EPOCHS) | ||
87 | + model_save_path = self.config.MODEL_SAVE_PATH + '_iter' + str(epoch_num) | ||
88 | + self.save(model_save_path) | ||
89 | + self.log('Saved after %d epochs in: %s' % (epoch_num, model_save_path)) | ||
90 | + evaluation_results = self.evaluate() | ||
91 | + evaluation_results_str = (str(evaluation_results).replace('topk', 'top{}'.format( | ||
92 | + self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION))) | ||
93 | + self.log('After {nr_epochs} epochs -- {evaluation_results}'.format( | ||
94 | + nr_epochs=epoch_num, | ||
95 | + evaluation_results=evaluation_results_str | ||
96 | + )) | ||
97 | + except tf.errors.OutOfRangeError: | ||
98 | +            self.log('Input dataset exhausted during training') | ||
99 | + pass # exhausted | ||
100 | + | ||
101 | + self.log('Done training') | ||
102 | + | ||
103 | + if self.config.MODEL_SAVE_PATH: | ||
104 | + self._save_inner_model(self.config.MODEL_SAVE_PATH) | ||
105 | + self.log('Model saved in file: %s' % self.config.MODEL_SAVE_PATH) | ||
106 | + | ||
107 | + elapsed = int(time.time() - start_time) | ||
108 | + self.log("Training time: %sH:%sM:%sS\n" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60)) | ||
109 | + | ||
110 | + def evaluate(self) -> Optional[ModelEvaluationResults]: | ||
111 | + eval_start_time = time.time() | ||
112 | + if self.eval_reader is None: | ||
113 | + self.eval_reader = PathContextReader(vocabs=self.vocabs, | ||
114 | + model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(), | ||
115 | + config=self.config, estimator_action=EstimatorAction.Evaluate) | ||
116 | + input_iterator = tf.compat.v1.data.make_initializable_iterator(self.eval_reader.get_dataset()) | ||
117 | + self.eval_input_iterator_reset_op = input_iterator.initializer | ||
118 | + input_tensors = input_iterator.get_next() | ||
119 | + | ||
120 | + self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, _, _, _, _, \ | ||
121 | + self.eval_code_vectors = self._build_tf_test_graph(input_tensors) | ||
122 | + if self.saver is None: | ||
123 | + self.saver = tf.compat.v1.train.Saver() | ||
124 | + | ||
125 | + if self.config.MODEL_LOAD_PATH and not self.config.TRAIN_DATA_PATH_PREFIX: | ||
126 | + self._initialize_session_variables() | ||
127 | + self._load_inner_model(self.sess) | ||
128 | + if self.config.RELEASE: | ||
129 | + release_name = self.config.MODEL_LOAD_PATH + '.release' | ||
130 | + self.log('Releasing model, output model: %s' % release_name) | ||
131 | + self.saver.save(self.sess, release_name) | ||
132 | + return None | ||
133 | + | ||
134 | + with open('log.txt', 'w') as log_output_file: | ||
135 | + if self.config.EXPORT_CODE_VECTORS: | ||
136 | + code_vectors_file = open(self.config.TEST_DATA_PATH + '.vectors', 'w') | ||
137 | + total_predictions = 0 | ||
138 | + total_prediction_batches = 0 | ||
139 | + subtokens_evaluation_metric = SubtokensEvaluationMetric( | ||
140 | + partial(common.filter_impossible_names, self.vocabs.target_vocab.special_words)) | ||
141 | + topk_accuracy_evaluation_metric = TopKAccuracyEvaluationMetric( | ||
142 | + self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION, | ||
143 | + partial(common.get_first_match_word_from_top_predictions, self.vocabs.target_vocab.special_words)) | ||
144 | + start_time = time.time() | ||
145 | + | ||
146 | + self.sess.run(self.eval_input_iterator_reset_op) | ||
147 | + | ||
148 | + self.log('Starting evaluation') | ||
149 | + | ||
150 | + batch_num = 0 | ||
151 | + try: | ||
152 | + while batch_num <= self.MAX_BATCH_NUM: | ||
153 | + batch_num += 1 | ||
154 | + | ||
155 | + top_words, top_scores, original_names, code_vectors = self.sess.run( | ||
156 | + [self.eval_top_words_op, self.eval_top_values_op, | ||
157 | + self.eval_original_names_op, self.eval_code_vectors], | ||
158 | + ) | ||
159 | + | ||
160 | + top_words = common.binary_to_string_matrix(top_words) # (batch, top_k) | ||
161 | + original_names = common.binary_to_string_list(original_names) # (batch,) | ||
162 | + | ||
163 | + self._log_predictions_during_evaluation(zip(original_names, top_words), log_output_file) | ||
164 | + topk_accuracy_evaluation_metric.update_batch(zip(original_names, top_words)) | ||
165 | + subtokens_evaluation_metric.update_batch(zip(original_names, top_words)) | ||
166 | + | ||
167 | + total_predictions += len(original_names) | ||
168 | + total_prediction_batches += 1 | ||
169 | + if self.config.EXPORT_CODE_VECTORS: | ||
170 | + self._write_code_vectors(code_vectors_file, code_vectors) | ||
171 | + if total_prediction_batches % self.config.NUM_BATCHES_TO_LOG_PROGRESS == 0: | ||
172 | + elapsed = time.time() - start_time | ||
173 | + self._trace_evaluation(total_predictions, elapsed) | ||
174 | + | ||
175 | + except tf.errors.OutOfRangeError: | ||
176 | +                self.log('Input dataset exhausted during evaluation') | ||
177 | + pass | ||
178 | + | ||
179 | + self.log('Done evaluating, epoch reached') | ||
180 | + log_output_file.write(str(topk_accuracy_evaluation_metric.topk_correct_predictions) + '\n') | ||
181 | + | ||
182 | + if self.config.EXPORT_CODE_VECTORS: | ||
183 | + code_vectors_file.close() | ||
184 | + | ||
185 | + elapsed = int(time.time() - eval_start_time) | ||
186 | + self.log("Evaluation time: %sH:%sM:%sS" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60)) | ||
187 | + return ModelEvaluationResults( | ||
188 | + topk_acc=topk_accuracy_evaluation_metric.topk_correct_predictions, | ||
189 | + subtoken_precision=subtokens_evaluation_metric.precision, | ||
190 | + subtoken_recall=subtokens_evaluation_metric.recall, | ||
191 | + subtoken_f1=subtokens_evaluation_metric.f1) | ||
192 | + | ||
193 | + def _build_tf_training_graph(self, input_tensors): | ||
194 | + input_tensors = _TFTrainModelInputTensorsFormer().from_model_input_form(input_tensors) | ||
195 | + | ||
196 | + with tf.compat.v1.variable_scope('model'): | ||
197 | + tokens_vocab = tf.compat.v1.get_variable( | ||
198 | + self.vocab_type_to_tf_variable_name_mapping[VocabType.Token], | ||
199 | + shape=(self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE), dtype=tf.float32, | ||
200 | + initializer=tf.compat.v1.initializers.variance_scaling(scale=1.0, mode='fan_out', distribution="uniform")) | ||
201 | + targets_vocab = tf.compat.v1.get_variable( | ||
202 | + self.vocab_type_to_tf_variable_name_mapping[VocabType.Target], | ||
203 | + shape=(self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE), dtype=tf.float32, | ||
204 | + initializer=tf.compat.v1.initializers.variance_scaling(scale=1.0, mode='fan_out', distribution="uniform")) | ||
205 | + attention_param = tf.compat.v1.get_variable( | ||
206 | + 'ATTENTION', | ||
207 | + shape=(self.config.CODE_VECTOR_SIZE, 1), dtype=tf.float32) | ||
208 | + paths_vocab = tf.compat.v1.get_variable( | ||
209 | + self.vocab_type_to_tf_variable_name_mapping[VocabType.Path], | ||
210 | + shape=(self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE), dtype=tf.float32, | ||
211 | + initializer=tf.compat.v1.initializers.variance_scaling(scale=1.0, mode='fan_out', distribution="uniform")) | ||
212 | + | ||
213 | + code_vectors, _ = self._calculate_weighted_contexts( | ||
214 | + tokens_vocab, paths_vocab, attention_param, input_tensors.path_source_token_indices, | ||
215 | + input_tensors.path_indices, input_tensors.path_target_token_indices, input_tensors.context_valid_mask) | ||
216 | + | ||
217 | + logits = tf.matmul(code_vectors, targets_vocab, transpose_b=True) | ||
218 | + batch_size = tf.cast(tf.shape(input_tensors.target_index)[0], dtype=tf.float32) | ||
219 | + loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits( | ||
220 | + labels=tf.reshape(input_tensors.target_index, [-1]), | ||
221 | + logits=logits)) / batch_size | ||
222 | + | ||
223 | + optimizer = tf.compat.v1.train.AdamOptimizer().minimize(loss) | ||
224 | + | ||
225 | + return optimizer, loss | ||
226 | + | ||
227 | + def _calculate_weighted_contexts(self, tokens_vocab, paths_vocab, attention_param, source_input, path_input, | ||
228 | + target_input, valid_mask, is_evaluating=False): | ||
229 | + source_word_embed = tf.nn.embedding_lookup(params=tokens_vocab, ids=source_input) | ||
230 | + path_embed = tf.nn.embedding_lookup(params=paths_vocab, ids=path_input) | ||
231 | + target_word_embed = tf.nn.embedding_lookup(params=tokens_vocab, ids=target_input) | ||
232 | + | ||
233 | + context_embed = tf.concat([source_word_embed, path_embed, target_word_embed], | ||
234 | + axis=-1) | ||
235 | + | ||
236 | + if not is_evaluating: | ||
237 | + context_embed = tf.nn.dropout(context_embed, rate=1-self.config.DROPOUT_KEEP_RATE) | ||
238 | + | ||
239 | + flat_embed = tf.reshape(context_embed, [-1, self.config.context_vector_size]) | ||
240 | + transform_param = tf.compat.v1.get_variable( | ||
241 | + 'TRANSFORM', shape=(self.config.context_vector_size, self.config.CODE_VECTOR_SIZE), dtype=tf.float32) | ||
242 | + | ||
243 | + flat_embed = tf.tanh(tf.matmul(flat_embed, transform_param)) | ||
244 | + | ||
245 | + contexts_weights = tf.matmul(flat_embed, attention_param) | ||
246 | + batched_contexts_weights = tf.reshape( | ||
247 | + contexts_weights, [-1, self.config.MAX_CONTEXTS, 1]) | ||
248 | + mask = tf.math.log(valid_mask) | ||
249 | + mask = tf.expand_dims(mask, axis=2) | ||
250 | + batched_contexts_weights += mask | ||
251 | + attention_weights = tf.nn.softmax(batched_contexts_weights, axis=1) | ||
252 | + | ||
253 | + batched_embed = tf.reshape(flat_embed, shape=[-1, self.config.MAX_CONTEXTS, self.config.CODE_VECTOR_SIZE]) | ||
254 | + code_vectors = tf.reduce_sum(tf.multiply(batched_embed, attention_weights), axis=1) | ||
255 | + | ||
256 | + return code_vectors, attention_weights | ||
257 | + | ||
258 | + def _build_tf_test_graph(self, input_tensors, normalize_scores=False): | ||
259 | + with tf.compat.v1.variable_scope('model', reuse=self.get_should_reuse_variables()): | ||
260 | + tokens_vocab = tf.compat.v1.get_variable( | ||
261 | + self.vocab_type_to_tf_variable_name_mapping[VocabType.Token], | ||
262 | + shape=(self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE), | ||
263 | + dtype=tf.float32, trainable=False) | ||
264 | + targets_vocab = tf.compat.v1.get_variable( | ||
265 | + self.vocab_type_to_tf_variable_name_mapping[VocabType.Target], | ||
266 | + shape=(self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE), | ||
267 | + dtype=tf.float32, trainable=False) | ||
268 | + attention_param = tf.compat.v1.get_variable( | ||
269 | + 'ATTENTION', shape=(self.config.context_vector_size, 1), | ||
270 | + dtype=tf.float32, trainable=False) | ||
271 | + paths_vocab = tf.compat.v1.get_variable( | ||
272 | + self.vocab_type_to_tf_variable_name_mapping[VocabType.Path], | ||
273 | + shape=(self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE), | ||
274 | + dtype=tf.float32, trainable=False) | ||
275 | + | ||
276 | + targets_vocab = tf.transpose(targets_vocab) | ||
277 | + | ||
278 | + input_tensors = _TFEvaluateModelInputTensorsFormer().from_model_input_form(input_tensors) | ||
279 | + | ||
280 | + code_vectors, attention_weights = self._calculate_weighted_contexts( | ||
281 | + tokens_vocab, paths_vocab, attention_param, input_tensors.path_source_token_indices, | ||
282 | + input_tensors.path_indices, input_tensors.path_target_token_indices, | ||
283 | + input_tensors.context_valid_mask, is_evaluating=True) | ||
284 | + | ||
285 | + scores = tf.matmul(code_vectors, targets_vocab) # (batch, target_word_vocab) | ||
286 | + | ||
287 | + topk_candidates = tf.nn.top_k(scores, k=tf.minimum( | ||
288 | + self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION, self.vocabs.target_vocab.size)) | ||
289 | + top_indices = topk_candidates.indices | ||
290 | + top_words = self.vocabs.target_vocab.lookup_word(top_indices) | ||
291 | + original_words = input_tensors.target_string | ||
292 | + top_scores = topk_candidates.values | ||
293 | + if normalize_scores: | ||
294 | + top_scores = tf.nn.softmax(top_scores) | ||
295 | + | ||
296 | + return top_words, top_scores, original_words, attention_weights, input_tensors.path_source_token_strings, \ | ||
297 | + input_tensors.path_strings, input_tensors.path_target_token_strings, code_vectors | ||
298 | + | ||
299 | + def predict(self, predict_data_lines: Iterable[str]) -> List[ModelPredictionResults]: | ||
300 | + if self.predict_reader is None: | ||
301 | + self.predict_reader = PathContextReader(vocabs=self.vocabs, | ||
302 | + model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(), | ||
303 | + config=self.config, estimator_action=EstimatorAction.Predict) | ||
304 | + self.predict_placeholder = tf.compat.v1.placeholder(tf.string) | ||
305 | + reader_output = self.predict_reader.process_input_row(self.predict_placeholder) | ||
306 | + | ||
307 | + self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op, \ | ||
308 | + self.attention_weights_op, self.predict_source_string, self.predict_path_string, \ | ||
309 | + self.predict_path_target_string, self.predict_code_vectors = \ | ||
310 | + self._build_tf_test_graph(reader_output, normalize_scores=True) | ||
311 | + | ||
312 | + self._initialize_session_variables() | ||
313 | + self.saver = tf.compat.v1.train.Saver() | ||
314 | + self._load_inner_model(sess=self.sess) | ||
315 | + | ||
316 | + prediction_results: List[ModelPredictionResults] = [] | ||
317 | + for line in predict_data_lines: | ||
318 | + batch_top_words, batch_top_scores, batch_original_name, batch_attention_weights, batch_path_source_strings,\ | ||
319 | + batch_path_strings, batch_path_target_strings, batch_code_vectors = self.sess.run( | ||
320 | + [self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op, | ||
321 | + self.attention_weights_op, self.predict_source_string, self.predict_path_string, | ||
322 | + self.predict_path_target_string, self.predict_code_vectors], | ||
323 | + feed_dict={self.predict_placeholder: line}) | ||
324 | + | ||
325 | + assert all(tensor.shape[0] == 1 for tensor in (batch_top_words, batch_top_scores, batch_original_name, | ||
326 | + batch_attention_weights, batch_path_source_strings, | ||
327 | + batch_path_strings, batch_path_target_strings, | ||
328 | + batch_code_vectors)) | ||
329 | + top_words = np.squeeze(batch_top_words, axis=0) | ||
330 | + top_scores = np.squeeze(batch_top_scores, axis=0) | ||
331 | + original_name = batch_original_name[0] | ||
332 | + attention_weights = np.squeeze(batch_attention_weights, axis=0) | ||
333 | + path_source_strings = np.squeeze(batch_path_source_strings, axis=0) | ||
334 | + path_strings = np.squeeze(batch_path_strings, axis=0) | ||
335 | + path_target_strings = np.squeeze(batch_path_target_strings, axis=0) | ||
336 | + code_vectors = np.squeeze(batch_code_vectors, axis=0) | ||
337 | + | ||
338 | + top_words = common.binary_to_string_list(top_words) | ||
339 | + original_name = common.binary_to_string(original_name) | ||
340 | + attention_per_context = self._get_attention_weight_per_context( | ||
341 | + path_source_strings, path_strings, path_target_strings, attention_weights) | ||
342 | + prediction_results.append(ModelPredictionResults( | ||
343 | + original_name=original_name, | ||
344 | + topk_predicted_words=top_words, | ||
345 | + topk_predicted_words_scores=top_scores, | ||
346 | + attention_per_context=attention_per_context, | ||
347 | + code_vector=(code_vectors if self.config.EXPORT_CODE_VECTORS else None) | ||
348 | + )) | ||
349 | + return prediction_results | ||
350 | + | ||
351 | + def _save_inner_model(self, path: str): | ||
352 | + self.saver.save(self.sess, path) | ||
353 | + | ||
354 | + def _load_inner_model(self, sess=None): | ||
355 | + if sess is not None: | ||
356 | + self.log('Loading model weights from: ' + self.config.MODEL_LOAD_PATH) | ||
357 | + self.saver.restore(sess, self.config.MODEL_LOAD_PATH) | ||
358 | + self.log('Done loading model weights') | ||
359 | + | ||
360 | + def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray: | ||
361 | + assert vocab_type in VocabType | ||
362 | + vocab_tf_variable_name = self.vocab_type_to_tf_variable_name_mapping[vocab_type] | ||
363 | + | ||
364 | + if self.eval_reader is None: | ||
365 | + self.eval_reader = PathContextReader(vocabs=self.vocabs, | ||
366 | + model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(), | ||
367 | + config=self.config, estimator_action=EstimatorAction.Evaluate) | ||
368 | + input_iterator = tf.compat.v1.data.make_initializable_iterator(self.eval_reader.get_dataset()) | ||
369 | + _, _, _, _, _, _, _, _ = self._build_tf_test_graph(input_iterator.get_next()) | ||
370 | + | ||
371 | + if vocab_type is VocabType.Token: | ||
372 | + shape = (self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE) | ||
373 | + elif vocab_type is VocabType.Target: | ||
374 | + shape = (self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE) | ||
375 | + elif vocab_type is VocabType.Path: | ||
376 | + shape = (self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE) | ||
377 | + | ||
378 | + with tf.compat.v1.variable_scope('model', reuse=True): | ||
379 | + embeddings = tf.compat.v1.get_variable(vocab_tf_variable_name, shape=shape) | ||
380 | + self.saver = tf.compat.v1.train.Saver() | ||
381 | + self._initialize_session_variables() | ||
382 | + self._load_inner_model(self.sess) | ||
383 | + vocab_embedding_matrix = self.sess.run(embeddings) | ||
384 | + return vocab_embedding_matrix | ||
385 | + | ||
386 | + def get_should_reuse_variables(self): | ||
387 | + if self.config.TRAIN_DATA_PATH_PREFIX: | ||
388 | + return True | ||
389 | + else: | ||
390 | + return None | ||
391 | + | ||
392 | + def _log_predictions_during_evaluation(self, results, output_file): | ||
393 | + for original_name, top_predicted_words in results: | ||
394 | + found_match = common.get_first_match_word_from_top_predictions( | ||
395 | + self.vocabs.target_vocab.special_words, original_name, top_predicted_words) | ||
396 | + if found_match is not None: | ||
397 | + prediction_idx, predicted_word = found_match | ||
398 | + if prediction_idx == 0: | ||
399 | + output_file.write('Original: ' + original_name + ', predicted 1st: ' + predicted_word + '\n') | ||
400 | + else: | ||
401 | + output_file.write('\t\t predicted correctly at rank: ' + str(prediction_idx + 1) + '\n') | ||
402 | + else: | ||
403 | +                output_file.write('No results for predicting: ' + original_name + '\n') | ||
404 | + | ||
405 | + def _trace_training(self, sum_loss, batch_num, multi_batch_start_time): | ||
406 | + multi_batch_elapsed = time.time() - multi_batch_start_time | ||
407 | + avg_loss = sum_loss / (self.config.NUM_BATCHES_TO_LOG_PROGRESS * self.config.TRAIN_BATCH_SIZE) | ||
408 | + throughput = self.config.TRAIN_BATCH_SIZE * self.config.NUM_BATCHES_TO_LOG_PROGRESS / \ | ||
409 | + (multi_batch_elapsed if multi_batch_elapsed > 0 else 1) | ||
410 | + self.log('Average loss at batch %d: %f, \tthroughput: %d samples/sec' % ( | ||
411 | + batch_num, avg_loss, throughput)) | ||
412 | + | ||
413 | + def _trace_evaluation(self, total_predictions, elapsed): | ||
414 | + state_message = 'Evaluated %d examples...' % total_predictions | ||
415 | + throughput_message = "Prediction throughput: %d samples/sec" % int( | ||
416 | + total_predictions / (elapsed if elapsed > 0 else 1)) | ||
417 | + self.log(state_message) | ||
418 | + self.log(throughput_message) | ||
419 | + | ||
420 | + def close_session(self): | ||
421 | + self.sess.close() | ||
422 | + | ||
423 | + def _initialize_session_variables(self): | ||
424 | + self.sess.run(tf.group( | ||
425 | + tf.compat.v1.global_variables_initializer(), | ||
426 | + tf.compat.v1.local_variables_initializer(), | ||
427 | + tf.compat.v1.tables_initializer())) | ||
428 | +        self.log('Initialized variables') | ||
429 | + | ||
430 | + | ||
431 | +class SubtokensEvaluationMetric: | ||
432 | + def __init__(self, filter_impossible_names_fn): | ||
433 | + self.nr_true_positives: int = 0 | ||
434 | + self.nr_false_positives: int = 0 | ||
435 | + self.nr_false_negatives: int = 0 | ||
436 | + self.nr_predictions: int = 0 | ||
437 | + self.filter_impossible_names_fn = filter_impossible_names_fn | ||
438 | + | ||
439 | + def update_batch(self, results): | ||
440 | + for original_name, top_words in results: | ||
441 | + try: | ||
442 | + possible_names = self.filter_impossible_names_fn(top_words) | ||
443 | + prediction = possible_names[0] | ||
444 | + original_subtokens = Counter(common.get_subtokens(original_name)) | ||
445 | + predicted_subtokens = Counter(common.get_subtokens(prediction)) | ||
446 | + self.nr_true_positives += sum(count for element, count in predicted_subtokens.items() | ||
447 | + if element in original_subtokens) | ||
448 | + self.nr_false_positives += sum(count for element, count in predicted_subtokens.items() | ||
449 | + if element not in original_subtokens) | ||
450 | + self.nr_false_negatives += sum(count for element, count in original_subtokens.items() | ||
451 | + if element not in predicted_subtokens) | ||
452 | + self.nr_predictions += 1 | ||
453 | + except Exception as e: | ||
454 | + print(e) | ||
455 | +                print("Original name:", original_name) | ||
456 | +                for p in top_words: | ||
457 | + print(p, end=' ') | ||
458 | + print('') | ||
459 | + print("Top Words:", top_words) | ||
460 | + raise | ||
461 | + | ||
462 | + @property | ||
463 | + def true_positive(self): | ||
464 | + return self.nr_true_positives / self.nr_predictions | ||
465 | + | ||
466 | + @property | ||
467 | + def false_positive(self): | ||
468 | + return self.nr_false_positives / self.nr_predictions | ||
469 | + | ||
470 | + @property | ||
471 | + def false_negative(self): | ||
472 | + return self.nr_false_negatives / self.nr_predictions | ||
473 | + | ||
474 | + @property | ||
475 | + def precision(self): | ||
476 | + return self.nr_true_positives / (self.nr_true_positives + self.nr_false_positives) | ||
477 | + | ||
478 | + @property | ||
479 | + def recall(self): | ||
480 | + return self.nr_true_positives / (self.nr_true_positives + self.nr_false_negatives) | ||
481 | + | ||
482 | + @property | ||
483 | + def f1(self): | ||
484 | + return 2 * self.precision * self.recall / (self.precision + self.recall) | ||
485 | + | ||
486 | + | ||
487 | +class TopKAccuracyEvaluationMetric: | ||
488 | + def __init__(self, top_k: int, get_first_match_word_from_top_predictions_fn): | ||
489 | + self.top_k = top_k | ||
490 | + self.nr_correct_predictions = np.zeros(self.top_k) | ||
491 | + self.nr_predictions: int = 0 | ||
492 | + self.get_first_match_word_from_top_predictions_fn = get_first_match_word_from_top_predictions_fn | ||
493 | + | ||
494 | + def update_batch(self, results): | ||
495 | + for original_name, top_predicted_words in results: | ||
496 | + self.nr_predictions += 1 | ||
497 | + found_match = self.get_first_match_word_from_top_predictions_fn(original_name, top_predicted_words) | ||
498 | + if found_match is not None: | ||
499 | + suggestion_idx, _ = found_match | ||
500 | + self.nr_correct_predictions[suggestion_idx:self.top_k] += 1 | ||
501 | + | ||
502 | + @property | ||
503 | + def topk_correct_predictions(self): | ||
504 | + return self.nr_correct_predictions / self.nr_predictions | ||
505 | + | ||
506 | + | ||
507 | +class _TFTrainModelInputTensorsFormer(ModelInputTensorsFormer): | ||
508 | + def to_model_input_form(self, input_tensors: ReaderInputTensors): | ||
509 | + return input_tensors.target_index, input_tensors.path_source_token_indices, input_tensors.path_indices, \ | ||
510 | + input_tensors.path_target_token_indices, input_tensors.context_valid_mask | ||
511 | + | ||
512 | + def from_model_input_form(self, input_row) -> ReaderInputTensors: | ||
513 | + return ReaderInputTensors( | ||
514 | + target_index=input_row[0], | ||
515 | + path_source_token_indices=input_row[1], | ||
516 | + path_indices=input_row[2], | ||
517 | + path_target_token_indices=input_row[3], | ||
518 | + context_valid_mask=input_row[4] | ||
519 | + ) | ||
520 | + | ||
521 | + | ||
522 | +class _TFEvaluateModelInputTensorsFormer(ModelInputTensorsFormer): | ||
523 | + def to_model_input_form(self, input_tensors: ReaderInputTensors): | ||
524 | + return (input_tensors.target_string, input_tensors.path_source_token_indices, input_tensors.path_indices, | ||
525 | + input_tensors.path_target_token_indices, input_tensors.context_valid_mask, | ||
526 | + input_tensors.path_source_token_strings, input_tensors.path_strings, | ||
527 | + input_tensors.path_target_token_strings) | ||
528 | + | ||
529 | + def from_model_input_form(self, input_row) -> ReaderInputTensors: | ||
530 | + return ReaderInputTensors( | ||
531 | + target_string=input_row[0], | ||
532 | + path_source_token_indices=input_row[1], | ||
533 | + path_indices=input_row[2], | ||
534 | + path_target_token_indices=input_row[3], | ||
535 | + context_valid_mask=input_row[4], | ||
536 | + path_source_token_strings=input_row[5], | ||
537 | + path_strings=input_row[6], | ||
538 | + path_target_token_strings=input_row[7] | ||
539 | + ) |
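To make the attention step easier to follow in isolation, here is a minimal NumPy sketch (an illustration, not part of the model) of what `_calculate_weighted_contexts` does after the tanh transform: score each context with the attention vector, mask padded contexts via log(0) = -inf, softmax over the contexts, and take the weighted sum as the code vector.

```python
import numpy as np

def attention_pool(contexts, attention_param, valid_mask):
    """contexts:        (batch, max_contexts, code_vector_size) transformed contexts
    attention_param: (code_vector_size, 1) learned attention vector
    valid_mask:      (batch, max_contexts), 1.0 for real contexts, 0.0 for padding
    """
    scores = contexts @ attention_param                        # (batch, max_contexts, 1)
    with np.errstate(divide='ignore'):                         # log(0) -> -inf for padding
        scores = scores + np.log(valid_mask)[..., np.newaxis]
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)              # softmax over contexts
    code_vectors = (contexts * weights).sum(axis=1)            # (batch, code_vector_size)
    return code_vectors, weights
```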
code/code2vec/train.sh
0 → 100644
1 | +type=python | ||
2 | +dataset_name=dataset | ||
3 | +data_dir=../data/${dataset_name} | ||
4 | +data=${data_dir}/${dataset_name} | ||
5 | +test_data=${data_dir}/${dataset_name}.val.c2v | ||
6 | +model_dir=models/${type} | ||
7 | + | ||
8 | +mkdir -p ${model_dir} | ||
9 | +set -e | ||
10 | +python -u code2vec.py --data ${data} --save ${model_dir}/saved_model --test ${test_data} |
code/code2vec/vocabularies.py
0 → 100644
1 | +from itertools import chain | ||
2 | +from typing import Optional, Dict, Iterable, Set, NamedTuple | ||
3 | +import pickle | ||
4 | +import os | ||
5 | +from enum import Enum | ||
6 | +from config import Config | ||
7 | +import tensorflow as tf | ||
8 | +from argparse import Namespace | ||
9 | + | ||
10 | +from common import common | ||
11 | + | ||
12 | + | ||
13 | +class VocabType(Enum): | ||
14 | + Token = 1 | ||
15 | + Target = 2 | ||
16 | + Path = 3 | ||
17 | + | ||
18 | + | ||
19 | +SpecialVocabWordsType = Namespace | ||
20 | + | ||
21 | + | ||
22 | +_SpecialVocabWords_OnlyOov = Namespace( | ||
23 | + OOV='<OOV>' | ||
24 | +) | ||
25 | + | ||
26 | +_SpecialVocabWords_SeparateOovPad = Namespace( | ||
27 | + PAD='<PAD>', | ||
28 | + OOV='<OOV>' | ||
29 | +) | ||
30 | + | ||
31 | +_SpecialVocabWords_JoinedOovPad = Namespace( | ||
32 | + PAD_OR_OOV='<PAD_OR_OOV>', | ||
33 | + PAD='<PAD_OR_OOV>', | ||
34 | + OOV='<PAD_OR_OOV>' | ||
35 | +) | ||
36 | + | ||
37 | + | ||
38 | +class Vocab: | ||
39 | + def __init__(self, vocab_type: VocabType, words: Iterable[str], | ||
40 | + special_words: Optional[SpecialVocabWordsType] = None): | ||
41 | + if special_words is None: | ||
42 | + special_words = Namespace() | ||
43 | + | ||
44 | + self.vocab_type = vocab_type | ||
45 | + self.word_to_index: Dict[str, int] = {} | ||
46 | + self.index_to_word: Dict[int, str] = {} | ||
47 | + self._word_to_index_lookup_table = None | ||
48 | + self._index_to_word_lookup_table = None | ||
49 | + self.special_words: SpecialVocabWordsType = special_words | ||
50 | + | ||
51 | + for index, word in enumerate(chain(common.get_unique_list(special_words.__dict__.values()), words)): | ||
52 | + self.word_to_index[word] = index | ||
53 | + self.index_to_word[index] = word | ||
54 | + | ||
55 | + self.size = len(self.word_to_index) | ||
56 | + | ||
57 | + def save_to_file(self, file): | ||
58 | + special_words_as_unique_list = common.get_unique_list(self.special_words.__dict__.values()) | ||
59 | + nr_special_words = len(special_words_as_unique_list) | ||
60 | + word_to_index_wo_specials = {word: idx for word, idx in self.word_to_index.items() if idx >= nr_special_words} | ||
61 | + index_to_word_wo_specials = {idx: word for idx, word in self.index_to_word.items() if idx >= nr_special_words} | ||
62 | + size_wo_specials = self.size - nr_special_words | ||
63 | + pickle.dump(word_to_index_wo_specials, file) | ||
64 | + pickle.dump(index_to_word_wo_specials, file) | ||
65 | + pickle.dump(size_wo_specials, file) | ||
66 | + | ||
67 | + @classmethod | ||
68 | + def load_from_file(cls, vocab_type: VocabType, file, special_words: SpecialVocabWordsType) -> 'Vocab': | ||
69 | + special_words_as_unique_list = common.get_unique_list(special_words.__dict__.values()) | ||
70 | + | ||
71 | + word_to_index_wo_specials = pickle.load(file) | ||
72 | + index_to_word_wo_specials = pickle.load(file) | ||
73 | + size_wo_specials = pickle.load(file) | ||
74 | + assert len(index_to_word_wo_specials) == len(word_to_index_wo_specials) == size_wo_specials | ||
75 | + min_word_idx_wo_specials = min(index_to_word_wo_specials.keys()) | ||
76 | + | ||
77 | + if min_word_idx_wo_specials != len(special_words_as_unique_list): | ||
78 | + raise ValueError( | ||
79 | + "Error while attempting to load vocabulary `{vocab_type}` from file `{file_path}`. " | ||
80 | + "The stored vocabulary has minimum word index {min_word_idx}, " | ||
81 | + "while expecting minimum word index to be {nr_special_words} " | ||
82 | + "because having to use {nr_special_words} special words, which are: {special_words}. " | ||
83 | + "Please check the parameter `config.SEPARATE_OOV_AND_PAD`.".format( | ||
84 | + vocab_type=vocab_type, file_path=file.name, min_word_idx=min_word_idx_wo_specials, | ||
85 | + nr_special_words=len(special_words_as_unique_list), special_words=special_words)) | ||
86 | + | ||
87 | + vocab = cls(vocab_type, [], special_words) | ||
88 | + vocab.word_to_index = {**word_to_index_wo_specials, | ||
89 | + **{word: idx for idx, word in enumerate(special_words_as_unique_list)}} | ||
90 | + vocab.index_to_word = {**index_to_word_wo_specials, | ||
91 | + **{idx: word for idx, word in enumerate(special_words_as_unique_list)}} | ||
92 | + vocab.size = size_wo_specials + len(special_words_as_unique_list) | ||
93 | + return vocab | ||
94 | + | ||
95 | + @classmethod | ||
96 | + def create_from_freq_dict(cls, vocab_type: VocabType, word_to_count: Dict[str, int], max_size: int, | ||
97 | + special_words: Optional[SpecialVocabWordsType] = None): | ||
98 | + if special_words is None: | ||
99 | + special_words = Namespace() | ||
100 | + words_sorted_by_counts = sorted(word_to_count, key=word_to_count.get, reverse=True) | ||
101 | + words_sorted_by_counts_and_limited = words_sorted_by_counts[:max_size] | ||
102 | + return cls(vocab_type, words_sorted_by_counts_and_limited, special_words) | ||
103 | + | ||
104 | + @staticmethod | ||
105 | + def _create_word_to_index_lookup_table(word_to_index: Dict[str, int], default_value: int): | ||
106 | + return tf.lookup.StaticHashTable( | ||
107 | + tf.lookup.KeyValueTensorInitializer( | ||
108 | + list(word_to_index.keys()), list(word_to_index.values()), key_dtype=tf.string, value_dtype=tf.int32), | ||
109 | + default_value=tf.constant(default_value, dtype=tf.int32)) | ||
110 | + | ||
111 | + @staticmethod | ||
112 | + def _create_index_to_word_lookup_table(index_to_word: Dict[int, str], default_value: str) \ | ||
113 | + -> tf.lookup.StaticHashTable: | ||
114 | + return tf.lookup.StaticHashTable( | ||
115 | + tf.lookup.KeyValueTensorInitializer( | ||
116 | + list(index_to_word.keys()), list(index_to_word.values()), key_dtype=tf.int32, value_dtype=tf.string), | ||
117 | + default_value=tf.constant(default_value, dtype=tf.string)) | ||
118 | + | ||
119 | + def get_word_to_index_lookup_table(self) -> tf.lookup.StaticHashTable: | ||
120 | + if self._word_to_index_lookup_table is None: | ||
121 | + self._word_to_index_lookup_table = self._create_word_to_index_lookup_table( | ||
122 | + self.word_to_index, default_value=self.word_to_index[self.special_words.OOV]) | ||
123 | + return self._word_to_index_lookup_table | ||
124 | + | ||
125 | + def get_index_to_word_lookup_table(self) -> tf.lookup.StaticHashTable: | ||
126 | + if self._index_to_word_lookup_table is None: | ||
127 | + self._index_to_word_lookup_table = self._create_index_to_word_lookup_table( | ||
128 | + self.index_to_word, default_value=self.special_words.OOV) | ||
129 | + return self._index_to_word_lookup_table | ||
130 | + | ||
131 | + def lookup_index(self, word: tf.Tensor) -> tf.Tensor: | ||
132 | + return self.get_word_to_index_lookup_table().lookup(word) | ||
133 | + | ||
134 | + def lookup_word(self, index: tf.Tensor) -> tf.Tensor: | ||
135 | + return self.get_index_to_word_lookup_table().lookup(index) | ||
136 | + | ||
137 | + | ||
138 | +WordFreqDictType = Dict[str, int] | ||
139 | + | ||
140 | + | ||
141 | +class Code2VecWordFreqDicts(NamedTuple): | ||
142 | + token_to_count: WordFreqDictType | ||
143 | + path_to_count: WordFreqDictType | ||
144 | + target_to_count: WordFreqDictType | ||
145 | + | ||
146 | + | ||
147 | +class Code2VecVocabs: | ||
148 | + def __init__(self, config: Config): | ||
149 | + self.config = config | ||
150 | + self.token_vocab: Optional[Vocab] = None | ||
151 | + self.path_vocab: Optional[Vocab] = None | ||
152 | + self.target_vocab: Optional[Vocab] = None | ||
153 | + | ||
154 | + self._already_saved_in_paths: Set[str] = set() | ||
155 | + | ||
156 | + self._load_or_create() | ||
157 | + | ||
158 | + def _load_or_create(self): | ||
159 | + assert self.config.is_training or self.config.is_loading | ||
160 | + if self.config.is_loading: | ||
161 | + vocabularies_load_path = self.config.get_vocabularies_path_from_model_path(self.config.MODEL_LOAD_PATH) | ||
162 | + if not os.path.isfile(vocabularies_load_path): | ||
163 | + raise ValueError( | ||
164 | + "Model dictionaries file is not found in model load dir. " | ||
165 | + "Expecting file `{vocabularies_load_path}`.".format(vocabularies_load_path=vocabularies_load_path)) | ||
166 | + self._load_from_path(vocabularies_load_path) | ||
167 | + else: | ||
168 | + self._create_from_word_freq_dict() | ||
169 | + | ||
170 | + def _load_from_path(self, vocabularies_load_path: str): | ||
171 | + assert os.path.exists(vocabularies_load_path) | ||
172 | + self.config.log('Loading model vocabularies from: `%s` ... ' % vocabularies_load_path) | ||
173 | + with open(vocabularies_load_path, 'rb') as file: | ||
174 | + self.token_vocab = Vocab.load_from_file( | ||
175 | + VocabType.Token, file, self._get_special_words_by_vocab_type(VocabType.Token)) | ||
176 | + self.target_vocab = Vocab.load_from_file( | ||
177 | + VocabType.Target, file, self._get_special_words_by_vocab_type(VocabType.Target)) | ||
178 | + self.path_vocab = Vocab.load_from_file( | ||
179 | + VocabType.Path, file, self._get_special_words_by_vocab_type(VocabType.Path)) | ||
180 | + self.config.log('Done loading model vocabularies.') | ||
181 | + self._already_saved_in_paths.add(vocabularies_load_path) | ||
182 | + | ||
183 | + def _create_from_word_freq_dict(self): | ||
184 | + word_freq_dict = self._load_word_freq_dict() | ||
185 | + self.config.log('Word frequencies dictionaries loaded. Now creating vocabularies.') | ||
186 | + self.token_vocab = Vocab.create_from_freq_dict( | ||
187 | + VocabType.Token, word_freq_dict.token_to_count, self.config.MAX_TOKEN_VOCAB_SIZE, | ||
188 | + special_words=self._get_special_words_by_vocab_type(VocabType.Token)) | ||
189 | + self.config.log('Created token vocab. size: %d' % self.token_vocab.size) | ||
190 | + self.path_vocab = Vocab.create_from_freq_dict( | ||
191 | + VocabType.Path, word_freq_dict.path_to_count, self.config.MAX_PATH_VOCAB_SIZE, | ||
192 | + special_words=self._get_special_words_by_vocab_type(VocabType.Path)) | ||
193 | + self.config.log('Created path vocab. size: %d' % self.path_vocab.size) | ||
194 | + self.target_vocab = Vocab.create_from_freq_dict( | ||
195 | + VocabType.Target, word_freq_dict.target_to_count, self.config.MAX_TARGET_VOCAB_SIZE, | ||
196 | + special_words=self._get_special_words_by_vocab_type(VocabType.Target)) | ||
197 | + self.config.log('Created target vocab. size: %d' % self.target_vocab.size) | ||
198 | + | ||
199 | + def _get_special_words_by_vocab_type(self, vocab_type: VocabType) -> SpecialVocabWordsType: | ||
200 | + if not self.config.SEPARATE_OOV_AND_PAD: | ||
201 | + return _SpecialVocabWords_JoinedOovPad | ||
202 | + if vocab_type == VocabType.Target: | ||
203 | + return _SpecialVocabWords_OnlyOov | ||
204 | + return _SpecialVocabWords_SeparateOovPad | ||
205 | + | ||
206 | + def save(self, vocabularies_save_path: str): | ||
207 | + if vocabularies_save_path in self._already_saved_in_paths: | ||
208 | + return | ||
209 | + with open(vocabularies_save_path, 'wb') as file: | ||
210 | + self.token_vocab.save_to_file(file) | ||
211 | + self.target_vocab.save_to_file(file) | ||
212 | + self.path_vocab.save_to_file(file) | ||
213 | + self._already_saved_in_paths.add(vocabularies_save_path) | ||
214 | + | ||
215 | + def _load_word_freq_dict(self) -> Code2VecWordFreqDicts: | ||
216 | + assert self.config.is_training | ||
217 | + self.config.log('Loading word frequencies dictionaries from: %s ... ' % self.config.word_freq_dict_path) | ||
218 | + with open(self.config.word_freq_dict_path, 'rb') as file: | ||
219 | + token_to_count = pickle.load(file) | ||
220 | + path_to_count = pickle.load(file) | ||
221 | + target_to_count = pickle.load(file) | ||
222 | + self.config.log('Done loading word frequencies dictionaries.') | ||
223 | + | ||
224 | + return Code2VecWordFreqDicts( | ||
225 | + token_to_count=token_to_count, path_to_count=path_to_count, target_to_count=target_to_count) | ||
226 | + | ||
227 | + def get(self, vocab_type: VocabType) -> Vocab: | ||
228 | + if not isinstance(vocab_type, VocabType): | ||
229 | + raise ValueError('`vocab_type` should be `VocabType.Token`, `VocabType.Target` or `VocabType.Path`.') | ||
230 | + if vocab_type == VocabType.Token: | ||
231 | + return self.token_vocab | ||
232 | + if vocab_type == VocabType.Target: | ||
233 | + return self.target_vocab | ||
234 | + if vocab_type == VocabType.Path: | ||
235 | + return self.path_vocab |
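A small sketch of how a Vocab is laid out (illustration only; the frequency dict below is made up): the special words take the first indices, then words are added in descending frequency order up to max_size.

```python
from vocabularies import Vocab, VocabType, _SpecialVocabWords_SeparateOovPad

freq = {'get': 120, 'value': 75, 'index': 30, 'tmp': 2}   # made-up counts
vocab = Vocab.create_from_freq_dict(
    VocabType.Token, freq, max_size=3,
    special_words=_SpecialVocabWords_SeparateOovPad)

print(vocab.size)                    # 5: two special words + the 3 kept words
print(vocab.word_to_index['<OOV>'])  # special words occupy the lowest indices
print('tmp' in vocab.word_to_index)  # False: dropped by max_size=3
```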
code/crawler/crawler.py
0 → 100644
1 | +from github import Github | ||
2 | +import time | ||
3 | +import calendar | ||
4 | + | ||
5 | +DATASET_MAX = 1000 | ||
6 | + | ||
7 | +class GithubCrawler: | ||
8 | + def __init__(self, token): | ||
9 | + self._token = token | ||
10 | + self._g = Github(token) | ||
11 | + | ||
12 | + def getTimeLimit(self): | ||
13 | + core_rate_limit = self._g.get_rate_limit().core | ||
14 | + reset_timestamp = calendar.timegm(core_rate_limit.reset.timetuple()) | ||
15 | + sleep_time = reset_timestamp - calendar.timegm(time.gmtime()) + 1 | ||
16 | + return sleep_time | ||
17 | + | ||
18 | + def search_repo(self, keywords, S = 0, E = DATASET_MAX): | ||
19 | + if type(keywords) == str: | ||
20 | + keywords = [keywords] #auto packing for one keyword | ||
21 | + | ||
22 | + query = '+'.join(keywords) + '+in:readme+in:description' | ||
23 | + result = self._g.search_repositories(query) | ||
24 | + | ||
25 | + ret = [] | ||
26 | + for i in range(S, E): | ||
27 | + while True: | ||
28 | + try: | ||
29 | + r = result[i] | ||
30 | + repoName = r.owner.login+'/'+r.name | ||
31 | + print("repo found", f"[{i}]:", repoName) | ||
32 | + ret.append(repoName) | ||
33 | + break | ||
34 | + except Exception: | ||
35 | +                    print("Rate Limit Exceeded... Retrying", f"[{i}]", "Limit Time:", self.getTimeLimit()) | ||
36 | + time.sleep(1) | ||
37 | + | ||
38 | + return ret | ||
39 | + | ||
40 | + def search_files(self, repo_url, downloadLink = False): | ||
41 | + while True: | ||
42 | + try: | ||
43 | + repo = self._g.get_repo(repo_url) | ||
44 | + break | ||
45 | + except Exception as e: | ||
46 | + if '403' in str(e): | ||
47 | +                    print("Rate Limit Exceeded... Retrying", repo_url, "Limit Time:", self.getTimeLimit()) | ||
48 | +                    time.sleep(1) | ||
49 | +                    continue | ||
50 | +                print(e) | ||
51 | +                return [] | ||
52 | + | ||
53 | + try: | ||
54 | + contents = repo.get_contents("") | ||
55 | + except Exception: #empty repo | ||
56 | + return [] | ||
57 | + | ||
58 | + files = [] | ||
59 | + | ||
60 | + while contents: | ||
61 | + file_content = contents.pop(0) | ||
62 | + if file_content.type == 'dir': | ||
63 | + if 'lib' in file_content.path: #python lib is in repo (too many files) | ||
64 | + return [] | ||
65 | + contents.extend(repo.get_contents(file_content.path)) | ||
66 | + else: | ||
67 | + if downloadLink: | ||
68 | + files.append(file_content.download_url) | ||
69 | + else: | ||
70 | + files.append(file_content.path) | ||
71 | + | ||
72 | + return files | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
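A minimal usage sketch for the crawler (the token, keyword, and range are placeholders, and every call counts against the GitHub API rate limit):

```python
from crawler import GithubCrawler

c = GithubCrawler('YOUR_TOKEN_HERE')                # placeholder personal access token

repos = c.search_repo('MNIST', S=0, E=5)            # e.g. ['someuser/mnist-demo', ...]
for repo in repos:
    urls = c.search_files(repo, downloadLink=True)  # raw download URL of every file
    py_urls = [u for u in urls if u and u.endswith('.py')]
    print(repo, '->', len(py_urls), 'python files')
```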
code/crawler/main.py
0 → 100644
1 | +import crawler | ||
2 | +import os | ||
3 | +import utils | ||
4 | + | ||
5 | +TOKEN = 'YOUR_TOKEN_HERE' | ||
6 | +DATASET_DIR = 'YOUR_PATH_HERE' | ||
7 | +REPO_PATH = 'repos.txt' | ||
8 | + | ||
9 | +utils.removeEmptyDirectories(DATASET_DIR) | ||
10 | + | ||
11 | +c = crawler.GithubCrawler(TOKEN) | ||
12 | + | ||
13 | +if not os.path.exists(REPO_PATH): | ||
14 | + repos = c.search_repo('MNIST+language:python', 1000, 2000) | ||
15 | + f = open(REPO_PATH, 'w') | ||
16 | + for r in repos: | ||
17 | + f.write(r + '\n') | ||
18 | + f.close() | ||
19 | +else: | ||
20 | + f = open(REPO_PATH, 'r') | ||
21 | + repos = f.readlines() | ||
22 | + f.close() | ||
23 | + | ||
24 | +S = 0 | ||
25 | +L = len(repos) | ||
26 | +print("Found repositories:", L) | ||
27 | + | ||
28 | +for i in range(S, L): | ||
29 | + r = repos[i].strip() | ||
30 | + savename = r.replace('/', '_') | ||
31 | + print('Downloading', f'[{i}] :', savename) | ||
32 | + | ||
33 | + if os.path.exists(os.path.join(DATASET_DIR, savename)): | ||
34 | + continue | ||
35 | + | ||
36 | + files = c.search_files(r, True) | ||
37 | + files = list(filter(lambda x : utils.isformat(x, ['py', 'ipynb']), files)) | ||
38 | + if len(files) > 0: | ||
39 | + utils.downloadFiles(DATASET_DIR, savename, files) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/crawler/utils.py
0 → 100644
1 | +import os | ||
2 | +from requests import get | ||
3 | + | ||
4 | +def isformat(file, typenames): | ||
5 | + if type(file) != str: | ||
6 | + return False | ||
7 | + | ||
8 | + if type(typenames) == str: | ||
9 | + typenames = [typenames] | ||
10 | + | ||
11 | + dot = file.rfind('.') | ||
12 | + | ||
13 | + if dot < 0: | ||
14 | + for t in typenames: | ||
15 | + if file == t: | ||
16 | + return True | ||
17 | + return False | ||
18 | + | ||
19 | + ext = file[dot + 1 :] | ||
20 | + | ||
21 | + for t in typenames: | ||
22 | + if ext == t: | ||
23 | + return True | ||
24 | + | ||
25 | + return False | ||
26 | + | ||
27 | +def downloadFiles(root, dir, urls): | ||
28 | + if not os.path.exists(root): | ||
29 | + os.mkdir(root) | ||
30 | + | ||
31 | + path = os.path.join(root, dir) | ||
32 | + | ||
33 | + if not os.path.exists(path): | ||
34 | + os.mkdir(path) | ||
35 | + else: | ||
36 | + return | ||
37 | + | ||
38 | + for url in urls: | ||
39 | + name = os.path.basename(url) | ||
40 | + with open(os.path.join(path, name), 'wb') as f: | ||
41 | + try: | ||
42 | + response = get(url) | ||
43 | + f.write(response.content) | ||
44 | + | ||
45 | + except Exception as e: | ||
46 | + print(e) | ||
47 | + f.close() | ||
48 | + break | ||
49 | + | ||
50 | + f.close() | ||
51 | + | ||
52 | +def removeEmptyDirectories(root): | ||
53 | + cnt = 0 | ||
54 | + for dir in os.listdir(root): | ||
55 | + d = os.path.join(root, dir) | ||
56 | + if len(os.listdir(d)) == 0: #empty | ||
57 | + os.rmdir(d) | ||
58 | + cnt += 1 | ||
59 | + | ||
60 | + print(cnt, "empty directories removed") | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
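A quick sanity check of the crawler's isformat helper (not part of the repo), including the no-extension branch:

```python
import utils  # the crawler's utils module

print(utils.isformat('train.py', ['py', 'ipynb']))  # True  -> extension matches
print(utils.isformat('notes/readme.md', 'py'))      # False -> extension differs
print(utils.isformat('Makefile', 'Makefile'))       # True  -> no dot, exact name match
print(utils.isformat(42, 'py'))                     # False -> non-string input
```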
code/dataset_generator/block.py
0 → 100644
1 | +class Block: | ||
2 | + def __init__(self, type, line=''): | ||
3 | + self.blocks = list() | ||
4 | + self.code = line | ||
5 | + self.blockType = type | ||
6 | + self.indent = -1 | ||
7 | + | ||
8 | + def setIndent(self, indent): | ||
9 | + self.indent = indent | ||
10 | + | ||
11 | + def addLine(self, line): | ||
12 | + if len(self.code) > 0: | ||
13 | + self.code += '\n' | ||
14 | + self.code += line | ||
15 | + | ||
16 | + def addBlock(self, block): | ||
17 | + self.blocks.append(block) | ||
18 | + | ||
19 | + def debug(self): | ||
20 | + if self.blockType != 'TYPE_NORMAL': | ||
21 | + print("Block Info:", self.blockType, self.indent) | ||
22 | + print(self.code) | ||
23 | + | ||
24 | + for block in self.blocks: | ||
25 | + if block.indent <= self.indent: | ||
26 | + raise ValueError("Invalid Indent Error Occurred: {}, INDENT {} included in {}, INDENT {}".format(block.code, block.indent, self.code, self.indent)) | ||
27 | + block.debug() | ||
28 | + | ||
29 | + def __str__(self): | ||
30 | + if len(self.code) > 0: | ||
31 | + result = self.code + '\n' | ||
32 | + else: | ||
33 | + result = '' | ||
34 | + | ||
35 | + for block in self.blocks: | ||
36 | + result += block.__str__() | ||
37 | + | ||
38 | + return result | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
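A minimal sketch of how a Block tree is assembled and printed back out (illustration only; the block-type strings and indents below are made up, the real ones come from file_parser):

```python
from block import Block

root = Block('TYPE_ROOT')                 # top-level container, keeps its default indent of -1
func = Block('TYPE_FUNC', 'def foo():')
func.setIndent(0)
body = Block('TYPE_NORMAL', '    return 42')
body.setIndent(4)

func.addBlock(body)
root.addBlock(func)

root.debug()      # prints info for non-normal blocks, raises on inconsistent indents
print(str(root))  # reassembles the source text:
                  # def foo():
                  #     return 42
```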
code/dataset_generator/data_merger.py
0 → 100644
1 | +from utils import * | ||
2 | +import file_parser | ||
3 | +import random | ||
4 | + | ||
5 | +def merge_two_files(input, output): # pick two random files from input, merge and shuffle codes, print to output | ||
6 | + ori_files = [f for f in readdir(input) if is_extension(f, 'py')] | ||
7 | + files = ori_files.copy() | ||
8 | + random.shuffle(files) | ||
9 | + | ||
10 | + os.makedirs(output, exist_ok=True) # create the output directory if not exists | ||
11 | + log = open(os.path.join(output, 'log.txt'), 'w', encoding='utf8') | ||
12 | + | ||
13 | + index = 1 | ||
14 | + while len(files) > 0: | ||
15 | + if len(files) == 1: | ||
16 | + one = random.choice(ori_files) | ||
17 | + while one == files[0]: # why python doesn't have do while loop?? | ||
18 | + one = random.choice(ori_files) | ||
19 | + | ||
20 | + pick = [files[0], one] | ||
21 | + else: | ||
22 | + pick = files[:2] | ||
23 | + | ||
24 | + files = files[2:] | ||
25 | + | ||
26 | + lines1 = read_file(pick[0]) | ||
27 | + lines2 = read_file(pick[1]) | ||
28 | + | ||
29 | + print("Merging:", pick[0], pick[1]) | ||
30 | + | ||
31 | + block1 = file_parser.parse_block(lines1) | ||
32 | + block2 = file_parser.parse_block(lines2) | ||
33 | + | ||
34 | + for b in block2.blocks: | ||
35 | + block1.addBlock(b) | ||
36 | + | ||
37 | + shuffle_block(block1) | ||
38 | + write_block(os.path.join(output, '{}.py'.format(index)), block1) | ||
39 | + | ||
40 | + log.write('{}.py {} {}\n'.format(index, pick[0], pick[1])) | ||
41 | + index += 1 | ||
42 | + | ||
43 | + log.close() | ||
44 | + print("Done generating Merged Dataset") | ||
45 | + print("log.txt generated in output path, for merged file info. [merge_file_name file1 file2]") | ||
46 | + | ||
47 | + | ||
48 | +''' | ||
49 | + Usage: merge_two_files('data/original', 'data/merged') | ||
50 | +''' | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/dataset_generator/data_obfuscator.py
0 → 100644
1 | +from utils import * | ||
2 | +import file_parser | ||
3 | +import re | ||
4 | + | ||
5 | +# obfuscator v1 uses names from other methods (shuffles method names) | ||
6 | + | ||
7 | +def detect_vars(line): # detect variables and return range tuples. except for keywords | ||
8 | + ret = list() | ||
9 | + s = 0 | ||
10 | + e = 0 | ||
11 | + detected = False | ||
12 | + strException = False | ||
13 | + strCh = None | ||
14 | + line += ' ' # for last separator | ||
15 | + | ||
16 | + for i in range(len(line)): | ||
17 | + c = line[i] | ||
18 | + | ||
19 | + if not strException and (c == "'" or c == '"'): # we cannot remove string first, because index gets changed | ||
20 | + strCh = c | ||
21 | + strException = True | ||
22 | + continue | ||
23 | + | ||
24 | + if strException: | ||
25 | + if c == strCh: | ||
26 | + strException = False | ||
27 | + continue | ||
28 | + | ||
29 | + if not detected and re.match('[A-Za-z_]', c): | ||
30 | + detected = True | ||
31 | + s = i | ||
32 | + continue | ||
33 | + | ||
34 | + if detected and not re.match('[A-Za-z_0-9]', c): | ||
35 | + detected = False | ||
36 | + e = i | ||
37 | + ret.append((s, e)) | ||
38 | + | ||
39 | + return ret | ||
40 | + | ||
41 | +def obfuscate(lines, vars, dictionary, mapper): # obfuscate the code | ||
42 | + ret = list() | ||
43 | + ### write_file('D:/Develop/ori.py', lines) | ||
44 | + | ||
45 | + for line in lines: | ||
46 | + var_ranges = detect_vars(line) | ||
47 | + var_ranges = [(s, e) for (s, e) in var_ranges if line[s:e] in vars] # remove keywords (do not convert to words because of string exception) | ||
48 | + var_ranges.append((-1, -1)) # for out-of-range exception | ||
49 | + | ||
50 | + var_index = 0 | ||
51 | + new_line = '' | ||
52 | + i = 0 | ||
53 | + L = len(line) | ||
54 | + | ||
55 | + while i < L: | ||
56 | + if i == var_ranges[var_index][0]: # found var | ||
57 | + s, e = var_ranges[var_index] | ||
58 | + new_line += vars[mapper[dictionary[line[s:e]]]] | ||
59 | + i = e | ||
60 | + var_index += 1 | ||
61 | + else: | ||
62 | + new_line += line[i] | ||
63 | + i += 1 | ||
64 | + | ||
65 | + ret.append(new_line) | ||
66 | + | ||
67 | + ### write_file('D:/Develop/obf.py', ret) | ||
68 | + return ret | ||
69 | + | ||
70 | +def create_var_histogram(input, outPath): | ||
71 | + files = [f for f in readdir(input) if is_extension(f, 'py')] | ||
72 | + freq_dict = dict() | ||
73 | + | ||
74 | + for p in files: | ||
75 | + lines = read_file(p) | ||
76 | + lines = remove_unnecessary_comments(lines) | ||
77 | + | ||
78 | + for line in lines: | ||
79 | + file_parser.parse_keywords(line, freq_dict) | ||
80 | + | ||
81 | + hist = open(outPath, 'w', encoding='utf8') | ||
82 | + arr = sorted(freq_dict.items(), key=select_value) | ||
83 | + for i in arr: | ||
84 | + hist.write(str(i) + '\n') | ||
85 | + hist.close() | ||
86 | + | ||
87 | +def read_histogram(inputPath): | ||
88 | + lines = read_file(inputPath) | ||
89 | + ret = [] | ||
90 | + | ||
91 | + for line in lines: | ||
92 | + line = line.split("'")[1] | ||
93 | + ret.append(line) | ||
94 | + return ret | ||
95 | + | ||
96 | +def obfuscate_files(input, output, var=None, threshold=4000): # obfuscate variables. Guessing variable names from keyword frequency (threshold) if variable list is not given | ||
97 | + files = [f for f in readdir(input) if is_extension(f, 'py')] | ||
98 | + freq_dict = dict() | ||
99 | + codes = list() | ||
100 | + | ||
101 | + for p in files: | ||
102 | + lines = read_file(p) | ||
103 | + | ||
104 | + lines = remove_unnecessary_comments(lines) # IMPORTANT: remove comments from lines for preprocessing | ||
105 | + codes.append((p, lines)) | ||
106 | + | ||
107 | + if var == None: | ||
108 | + for line in lines: | ||
109 | + file_parser.parse_keywords(line, freq_dict) | ||
110 | + | ||
111 | + | ||
112 | + if var == None: # don't have variable list | ||
113 | + hist = open(os.path.join(output, 'log.txt'), 'w', encoding='utf8') | ||
114 | + arr = sorted(freq_dict.items(), key=select_value) | ||
115 | + for i in arr: | ||
116 | + hist.write(str(i) + '\n') | ||
117 | + hist.close() | ||
118 | + | ||
119 | + var, _ = threshold_dict(freq_dict, threshold) | ||
120 | + var = [v[0] for v in var] | ||
121 | + | ||
122 | + dictionary = create_dictionary(var) | ||
123 | + mapper = create_mapper(len(var)) | ||
124 | + | ||
125 | + ### obfuscate(codes[0][1], var, dictionary, mapper) | ||
126 | + | ||
127 | + for path, code in codes: | ||
128 | + obfuscated = obfuscate(code, var, dictionary, mapper) | ||
129 | + | ||
130 | + filepath = path.split(input)[1][1:] | ||
131 | + os.makedirs(os.path.join(output, filepath.split('\\')[0]), exist_ok=True) # create the output directory if not exists | ||
132 | + new_path = os.path.join(output, filepath) | ||
133 | + write_file(new_path, obfuscated) | ||
134 | + | ||
135 | + print("Done generating Obfuscated Dataset") | ||
136 | + | ||
137 | + | ||
138 | +''' | ||
139 | +Usage | ||
140 | +obfuscate_files('data/original', 'data/obfuscated') | ||
141 | +''' | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
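A quick check of detect_vars (illustration, not from the repo), showing that identifier-like spans are returned while string literals are skipped; filtering out keywords against the variable list happens later, inside obfuscate:

```python
from data_obfuscator import detect_vars

line = "result = compute(x, 'not a var')"
spans = detect_vars(line)
print([line[s:e] for s, e in spans])   # ['result', 'compute', 'x']
```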
code/dataset_generator/data_obfuscator_v2.py
0 → 100644
1 | +from utils import * | ||
2 | +import file_parser | ||
3 | +import re | ||
4 | +import random | ||
5 | +# obfuscator v2 generates random names for methods | ||
6 | + | ||
7 | +def random_character(start=False): | ||
8 | + if start: | ||
9 | + x = random.randint(0, 52) | ||
10 | + if x == 0: | ||
11 | + return '_' | ||
12 | + elif x <= 26: | ||
13 | + return chr(65 + x - 1) | ||
14 | + else: | ||
15 | + return chr(97 + x - 27) | ||
16 | + | ||
17 | + x = random.randint(0, 62) | ||
18 | + if x == 0: | ||
19 | + return '_' | ||
20 | + elif x <= 26: | ||
21 | + return chr(65 + x - 1) | ||
22 | + elif x <= 52: | ||
23 | + return chr(97 + x - 27) | ||
24 | + else: | ||
25 | + return str(x - 53) | ||
26 | + | ||
27 | + | ||
28 | +def create_mapper_v2(L): | ||
29 | + ret = [] | ||
30 | + while len(ret) < L: | ||
31 | + length = random.randint(0, 8) + 4 | ||
32 | + s = random_character(True) | ||
33 | + | ||
34 | + while len(s) < length: | ||
35 | + s += random_character() | ||
36 | + | ||
37 | + if not s in ret: | ||
38 | + ret.append(s) | ||
39 | + | ||
40 | + return ret | ||
41 | + | ||
42 | +def detect_vars(line): # detect variables and return range tuples. except for keywords | ||
43 | + ret = list() | ||
44 | + s = 0 | ||
45 | + e = 0 | ||
46 | + detected = False | ||
47 | + strException = False | ||
48 | + strCh = None | ||
49 | + line += ' ' # for last separator | ||
50 | + | ||
51 | + for i in range(len(line)): | ||
52 | + c = line[i] | ||
53 | + | ||
54 | + if not strException and (c == "'" or c == '"'): # we cannot remove string first, because index gets changed | ||
55 | + strCh = c | ||
56 | + strException = True | ||
57 | + continue | ||
58 | + | ||
59 | + if strException: | ||
60 | + if c == strCh: | ||
61 | + strException = False | ||
62 | + continue | ||
63 | + | ||
64 | + if not detected and re.match('[A-Za-z_]', c): | ||
65 | + detected = True | ||
66 | + s = i | ||
67 | + continue | ||
68 | + | ||
69 | + if detected and not re.match('[A-Za-z_0-9]', c): | ||
70 | + detected = False | ||
71 | + e = i | ||
72 | + ret.append((s, e)) | ||
73 | + | ||
74 | + return ret | ||
75 | + | ||
76 | +def obfuscate(lines, vars, dictionary, mapper): # obfuscate the code | ||
77 | + ret = list() | ||
78 | + ### write_file('D:/Develop/ori.py', lines) | ||
79 | + | ||
80 | + for line in lines: | ||
81 | + var_ranges = detect_vars(line) | ||
82 | + var_ranges = [(s, e) for (s, e) in var_ranges if line[s:e] in vars] # keep only ranges whose text is a known variable (drops keywords; ranges are used instead of extracted words because of string literals) | ||
83 | + var_ranges.append((-1, -1)) # sentinel so var_ranges[var_index] never goes out of range | ||
84 | + | ||
85 | + var_index = 0 | ||
86 | + new_line = '' | ||
87 | + i = 0 | ||
88 | + L = len(line) | ||
89 | + | ||
90 | + while i < L: | ||
91 | + if i == var_ranges[var_index][0]: # found var | ||
92 | + s, e = var_ranges[var_index] | ||
93 | + new_line += mapper[dictionary[line[s:e]]] | ||
94 | + i = e | ||
95 | + var_index += 1 | ||
96 | + else: | ||
97 | + new_line += line[i] | ||
98 | + i += 1 | ||
99 | + | ||
100 | + ret.append(new_line) | ||
101 | + | ||
102 | + ### write_file('D:/Develop/obf.py', ret) | ||
103 | + return ret | ||
104 | + | ||
105 | +def create_var_histogram(input, outPath): | ||
106 | + files = [f for f in readdir(input) if is_extension(f, 'py')] | ||
107 | + freq_dict = dict() | ||
108 | + | ||
109 | + for p in files: | ||
110 | + lines = read_file(p) | ||
111 | + lines = remove_unnecessary_comments(lines) | ||
112 | + | ||
113 | + for line in lines: | ||
114 | + file_parser.parse_keywords(line, freq_dict) | ||
115 | + | ||
116 | + hist = open(outPath, 'w', encoding='utf8') | ||
117 | + arr = sorted(freq_dict.items(), key=select_value) | ||
118 | + for i in arr: | ||
119 | + hist.write(str(i) + '\n') | ||
120 | + hist.close() | ||
121 | + | ||
122 | +def read_histogram(inputPath): | ||
123 | + lines = read_file(inputPath) | ||
124 | + ret = [] | ||
125 | + | ||
126 | + for line in lines: | ||
127 | + line = line.split("'")[1] | ||
128 | + ret.append(line) | ||
129 | + return ret | ||
130 | + | ||
131 | +def obfuscate_files(input, output, var=None, threshold=4000): # obfuscate variables; if no variable list is given, guess variable names as keywords whose frequency is below the threshold | ||
132 | + files = [f for f in readdir(input) if is_extension(f, 'py')] | ||
133 | + freq_dict = dict() | ||
134 | + codes = list() | ||
135 | + | ||
136 | + for p in files: | ||
137 | + lines = read_file(p) | ||
138 | + | ||
139 | + lines = remove_unnecessary_comments(lines) # IMPORTANT: remove comments from lines for preprocessing | ||
140 | + codes.append((p, lines)) | ||
141 | + | ||
142 | + if var is None: | ||
143 | + for line in lines: | ||
144 | + file_parser.parse_keywords(line, freq_dict) | ||
145 | + | ||
146 | + | ||
147 | + if var is None: # no variable list was given | ||
148 | + hist = open(os.path.join(output, 'log.txt'), 'w', encoding='utf8') | ||
149 | + arr = sorted(freq_dict.items(), key=select_value) | ||
150 | + for i in arr: | ||
151 | + hist.write(str(i) + '\n') | ||
152 | + hist.close() | ||
153 | + | ||
154 | + var, _ = threshold_dict(freq_dict, threshold) | ||
155 | + var = [v[0] for v in var] | ||
156 | + | ||
157 | + dictionary = create_dictionary(var) | ||
158 | + mapper = create_mapper_v2(len(var)) | ||
159 | + | ||
160 | + ### obfuscate(codes[0][1], var, dictionary, mapper) | ||
161 | + | ||
162 | + for path, code in codes: | ||
163 | + obfuscated = obfuscate(code, var, dictionary, mapper) | ||
164 | + | ||
165 | + filepath = path.split(input)[1][1:] | ||
166 | + os.makedirs(os.path.join(output, filepath.split('\\')[0]), exist_ok=True) # create the output directory if it does not exist | ||
167 | + new_path = os.path.join(output, filepath) | ||
168 | + write_file(new_path, obfuscated) | ||
169 | + | ||
170 | + print("Done generating Obfuscated Dataset") | ||
171 | + | ||
172 | + | ||
173 | +''' | ||
174 | +Usage | ||
175 | +obfuscate_files('data/original', 'data/obfuscated') | ||
176 | +''' | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
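A minimal standalone sketch of the random-identifier idea used by create_mapper_v2 above. The helper random_identifier is hypothetical (not part of the repository); it mirrors the same constraints: the first character is an underscore or letter, followed by 3 to 11 further identifier characters.

import random
import string

def random_identifier():
    # first character may not be a digit; the rest may be letters, digits or '_'
    first = random.choice('_' + string.ascii_letters)
    rest = ''.join(random.choice('_' + string.ascii_letters + string.digits)
                   for _ in range(random.randint(3, 11)))
    return first + rest

print(random_identifier())  # e.g. '_aZ3kqB9'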
code/dataset_generator/data_refiner.py
0 → 100644
1 | +from utils import * | ||
2 | +import file_parser | ||
3 | +import random | ||
4 | + | ||
5 | +def refine_files(input, output): # parse each file into a block tree and write the re-serialized code to output | ||
6 | + files = [f for f in readdir(input) if is_extension(f, 'py')] | ||
7 | + random.shuffle(files) | ||
8 | + | ||
9 | + for p in files: | ||
10 | + lines = read_file(p) | ||
11 | + | ||
12 | + print("Refining:", p) | ||
13 | + block = file_parser.parse_block(lines) | ||
14 | + | ||
15 | + filepath = p.split(input)[1][1:] | ||
16 | + os.makedirs(os.path.join(output, filepath.split('\\')[0]), exist_ok=True) # create the output directory if it does not exist | ||
17 | + path = os.path.join(output, filepath) | ||
18 | + write_block(path, block) | ||
19 | + | ||
20 | + print("Done generating Refined Dataset") | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/dataset_generator/data_shuffler.py
0 → 100644
1 | +from utils import * | ||
2 | +import file_parser | ||
3 | +import random | ||
4 | + | ||
5 | +def shuffle_files(input, output): # shuffle each file's code block order and write the result to output | ||
6 | + files = [f for f in readdir(input) if is_extension(f, 'py')] | ||
7 | + random.shuffle(files) | ||
8 | + | ||
9 | + for p in files: | ||
10 | + lines = read_file(p) | ||
11 | + | ||
12 | + print("Shuffling:", p) | ||
13 | + block = file_parser.parse_block(lines) | ||
14 | + shuffle_block(block) | ||
15 | + | ||
16 | + filepath = p.split(input)[1][1:] | ||
17 | + os.makedirs(os.path.join(output, filepath.split('\\')[0]), exist_ok=True) # create the output directory if it does not exist | ||
18 | + path = os.path.join(output, filepath) | ||
19 | + write_block(path, block) | ||
20 | + | ||
21 | + print("Done generating Shuffled Dataset") | ||
22 | + | ||
23 | + | ||
24 | +''' | ||
25 | +shuffle_files('data/original', 'data/shuffled') | ||
26 | +''' | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/dataset_generator/file_parser.py
0 → 100644
1 | +from utils import * | ||
2 | +import re | ||
3 | +import keyword | ||
4 | + | ||
5 | +''' | ||
6 | + Test multi-line comments | ||
7 | +''' | ||
8 | + | ||
9 | +LIBRARYS = list() | ||
10 | + | ||
11 | +def parse_keywords(line, out): # out : output dictionary to sum up frequencies | ||
12 | + line = line.strip() | ||
13 | + line = remove_string(line) | ||
14 | + result = '' | ||
15 | + | ||
16 | + for c in line: | ||
17 | + if re.match('[A-Za-z_@0-9]', c): | ||
18 | + result += c | ||
19 | + else: | ||
20 | + result += ' ' | ||
21 | + | ||
22 | + import_line = False | ||
23 | + prev_key = '' | ||
24 | + | ||
25 | + for key in result.split(' '): | ||
26 | + if not key or is_number(key) or key[0] in "0123456789": | ||
27 | + continue | ||
28 | + | ||
29 | + ## exceptions: skip Python keywords, imported library names, and decorators | ||
30 | + | ||
31 | + if key in ['from', 'import']: | ||
32 | + import_line = True | ||
33 | + | ||
34 | + if import_line and prev_key != 'as': | ||
35 | + if not key in LIBRARYS: | ||
36 | + LIBRARYS.append(key) | ||
37 | + prev_key = key | ||
38 | + continue | ||
39 | + | ||
40 | + if key in keyword.kwlist or key in LIBRARYS or '@' in key: | ||
41 | + prev_key = key | ||
42 | + continue | ||
43 | + | ||
44 | + prev_key = key | ||
45 | + | ||
46 | + ## | ||
47 | + | ||
48 | + if not key in out: | ||
49 | + out[key] = 1 | ||
50 | + else: | ||
51 | + out[key] += 1 | ||
52 | + | ||
53 | +def parse_block(lines): # parse to import / def / class / normal (if, for, etc) | ||
54 | + lines = remove_unnecessary_comments(lines) | ||
55 | + root = Block('TYPE_ROOT') # main block tree node | ||
56 | + block_stack = [root] | ||
57 | + i = 0 | ||
58 | + L = len(lines) | ||
59 | + # par_stack = list() | ||
60 | + # multi_string_stack = list() | ||
61 | + | ||
62 | + while i < L: | ||
63 | + line = lines[i] | ||
64 | + start_index = 0 | ||
65 | + indent_count = 0 | ||
66 | + | ||
67 | + while True: # count indents | ||
68 | + if line[start_index] == '\t': | ||
69 | + start_index += 1 | ||
70 | + indent_count += 4 | ||
71 | + elif line[start_index] == ' ': | ||
72 | + start_index += 1 | ||
73 | + indent_count += 1 | ||
74 | + else: | ||
75 | + break | ||
76 | + | ||
77 | + block = create_block_from_line(line) | ||
78 | + block.setIndent(indent_count) | ||
79 | + | ||
80 | + if block.blockType == 'TYPE_FACTORY': # for the '@' decorator (TYPE_FACTORY) exception: attach the following def/class line | ||
81 | + i += 1 | ||
82 | + | ||
83 | + temp = create_block_from_line(lines[i]) | ||
84 | + if temp.blockType == 'TYPE_CLASS': | ||
85 | + block.addLine(lines[i]) | ||
86 | + block.blockType = 'TYPE_CLASS' | ||
87 | + elif temp.blockType == 'TYPE_DEF': | ||
88 | + block.addLine(lines[i]) | ||
89 | + block.blockType = 'TYPE_DEF' | ||
90 | + else: # unknown type exception (factory single lines, or multi line code) | ||
91 | + i -= 1 # roll back | ||
92 | + | ||
93 | + ''' | ||
94 | + ### code for multi-line string/code detection, but it has too many exceptions (most code works well thanks to indent parsing) | ||
95 | + line = lines[i] | ||
96 | + if detect_parenthesis(line, par_stack) or detect_multi_string(line, multi_string_stack) or detect_multi_line_code(lines[i]): # code is not ended in a single line | ||
97 | + i += 1 | ||
98 | + while detect_parenthesis(lines[i], par_stack) or detect_multi_string(lines[i], multi_string_stack) or detect_multi_line_code(lines[i]): | ||
99 | + block.addLine(lines[i]) | ||
100 | + i += 1 | ||
101 | + | ||
102 | + block.addLine(lines[i]) | ||
103 | + ''' | ||
104 | + | ||
105 | + if indent_count == block_stack[-1].indent: # same indent -> change the block | ||
106 | + block_stack.pop() | ||
107 | + block_stack[-1].addBlock(block) | ||
108 | + block_stack.append(block) | ||
109 | + elif indent_count > block_stack[-1].indent: # block included in previous block | ||
110 | + block_stack[-1].addBlock(block) | ||
111 | + block_stack.append(block) | ||
112 | + else: # block ended | ||
113 | + while indent_count <= block_stack[-1].indent: | ||
114 | + block_stack.pop() | ||
115 | + block_stack[-1].addBlock(block) | ||
116 | + block_stack.append(block) | ||
117 | + i += 1 | ||
118 | + | ||
119 | + return root | ||
120 | + | ||
121 | + | ||
122 | +""" | ||
123 | + Usage | ||
124 | + | ||
125 | + path = 'data/test.py' | ||
126 | + f = open(path, 'r') | ||
127 | + lines = f.readlines() | ||
128 | + f.close() | ||
129 | + | ||
130 | + | ||
131 | + block = parse_block(lines) | ||
132 | + block.debug() | ||
133 | + | ||
134 | + | ||
135 | + ''' | ||
136 | + keywords = dict() | ||
137 | + parse_keywords(lines, keywords) | ||
138 | + | ||
139 | + for k, v in keywords.items(): | ||
140 | + print(k,':',v) | ||
141 | + | ||
142 | + a, b = threshold_dict(keywords, 3) | ||
143 | + | ||
144 | + print(a) | ||
145 | + print(b) | ||
146 | + ''' | ||
147 | +""" | ||
148 | + | ||
149 | +''' | ||
150 | +d = dict() | ||
151 | +parse_keywords('from test.library import a as x, b as y', d) | ||
152 | +print(d) | ||
153 | +''' | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
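parse_block above groups lines by indentation depth to decide when a block opens, continues, or closes. The following standalone sketch shows only that indent counting on a toy input; indent_of is a hypothetical helper, not part of the repository, but it counts leading whitespace the same way (a tab counts as 4).

def indent_of(line, tab_width=4):
    # count leading whitespace: tabs count as 4, spaces as 1
    count = 0
    for ch in line:
        if ch == '\t':
            count += tab_width
        elif ch == ' ':
            count += 1
        else:
            break
    return count

src = ["class A:", "    def f(self):", "        return 1", "x = A()"]
print([indent_of(s) for s in src])  # [0, 4, 8, 0]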
code/dataset_generator/main.py
0 → 100644
1 | +from utils import remove_string | ||
2 | +import utils | ||
3 | +import data_merger | ||
4 | +import data_refiner | ||
5 | +import data_shuffler | ||
6 | +import file_parser | ||
7 | +import data_obfuscator_v2 | ||
8 | + | ||
9 | +if __name__ == '__main__': | ||
10 | + input_path = 'data/original' | ||
11 | + data_refiner.refine_files(input_path, 'data/refined') | ||
12 | + data_merger.merge_two_files(input_path, 'data/merged') | ||
13 | + data_shuffler.shuffle_files(input_path, 'data/shuffled') | ||
14 | + vars = data_obfuscator_v2.read_histogram('data/histogram_v1.txt') | ||
15 | + data_obfuscator_v2.obfuscate_files(input_path, 'data/obfuscated2', vars) | ||
16 | + | ||
17 | + # utils.write_file('data/keyword_examples.txt', utils.search_keyword(input_path, 'rand')) | ||
18 | + # data_obfuscator.create_var_histogram(input_path, 'data/histogram.txt') |
code/dataset_generator/utils.py
0 → 100644
1 | +from block import Block | ||
2 | +import bisect | ||
3 | +import os | ||
4 | +import re | ||
5 | +import random | ||
6 | + | ||
7 | +TYPE_CLASS = ['class'] | ||
8 | +TYPE_DEF = ['def'] | ||
9 | +TYPE_IMPORT = ['from', 'import'] | ||
10 | +TYPE_CONDITOIN = ['if', 'elif', 'else', 'for', 'while', 'with'] | ||
11 | +multi_line_comments = ["'''", '"""'] | ||
12 | + | ||
13 | +def select_value(x): | ||
14 | + return x[1] | ||
15 | + | ||
16 | +def threshold_dict(d, val): # split dict in two by thresholding on value | ||
17 | + arr = sorted(d.items(), key=select_value) | ||
18 | + index = bisect.bisect_left([r[1] for r in arr], val) | ||
19 | + return arr[:index], arr[index:] | ||
20 | + | ||
21 | +def is_number(s): | ||
22 | + if s[0] == '-': | ||
23 | + s = s[1:] | ||
24 | + return s.replace('.','',1).isdigit() | ||
25 | + | ||
26 | +def is_extension(f, ext): | ||
27 | + return os.path.splitext(f)[1][1:] == ext | ||
28 | + | ||
29 | +def _readdir_r(dirpath): # readdir for recursive | ||
30 | + ret = [] | ||
31 | + for f in os.listdir(dirpath): | ||
32 | + ret.append(os.path.join(dirpath, f)) | ||
33 | + | ||
34 | + return ret | ||
35 | + | ||
36 | +def readdir(path): # read files from the directory | ||
37 | + pathList = [path] | ||
38 | + result = [] | ||
39 | + i = 0 | ||
40 | + | ||
41 | + while i < len(pathList): | ||
42 | + f = pathList[i] | ||
43 | + if os.path.isdir(f): | ||
44 | + pathList += _readdir_r(f) | ||
45 | + else: | ||
46 | + result.append(f) | ||
47 | + | ||
48 | + i += 1 | ||
49 | + | ||
50 | + return result | ||
51 | + | ||
52 | +def remove_string(line): | ||
53 | + strIn = False | ||
54 | + strCh = None | ||
55 | + result = '' | ||
56 | + i = 0 | ||
57 | + L = len(line) | ||
58 | + | ||
59 | + while i < L: | ||
60 | + if i + 3 < L: | ||
61 | + if line[i:i+3] in multi_line_comments: | ||
62 | + if not strIn: | ||
63 | + strIn = True | ||
64 | + strCh = line[i:i+3] | ||
65 | + elif line[i:i+3] == strCh: | ||
66 | + strIn = False | ||
67 | + | ||
68 | + i += 2 | ||
69 | + continue | ||
70 | + | ||
71 | + c = line[i] | ||
72 | + i += 1 | ||
73 | + | ||
74 | + if c == '\'' or c == '\"': | ||
75 | + if not strIn: | ||
76 | + strIn = True | ||
77 | + strCh = c | ||
78 | + elif c == strCh: | ||
79 | + strIn = False | ||
80 | + continue | ||
81 | + | ||
82 | + if strIn: | ||
83 | + continue | ||
84 | + | ||
85 | + result += c | ||
86 | + | ||
87 | + return result | ||
88 | + | ||
89 | +def using_multi_string(line, index): | ||
90 | + line = line.strip() | ||
91 | + for comment in multi_line_comments: | ||
92 | + if line.find(comment, index) > 0: | ||
93 | + return True | ||
94 | + return False | ||
95 | + | ||
96 | +def remove_unnecessary_comments(lines): | ||
97 | + # Warning: cannot reliably detect all multi-line comments, because a multi-line comment is syntactically just a multi-line string. | ||
98 | + | ||
99 | + # TODO: the multi-line string parser will not work well when a line uses more than one string (or comment), e.g.: | ||
100 | + # ex) a = ''' d ''' + ''' | ||
101 | + # abc ''' + ''' | ||
102 | + # x''' | ||
103 | + | ||
104 | + result = [] | ||
105 | + multi_line = False | ||
106 | + multi_string = False | ||
107 | + strCh = None | ||
108 | + | ||
109 | + for line in lines: | ||
110 | + find_str_index = 0 | ||
111 | + if multi_string: | ||
112 | + if strCh in line: | ||
113 | + find_str_index = line.find(strCh) + 3 | ||
114 | + multi_string = False | ||
115 | + strCh = None | ||
116 | + | ||
117 | + result.append(line) | ||
118 | + continue | ||
119 | + | ||
120 | + if multi_line: # parsing multi-line comments | ||
121 | + if strCh in line: | ||
122 | + multi_line = False | ||
123 | + strCh = None | ||
124 | + continue | ||
125 | + | ||
126 | + if using_multi_string(line, find_str_index): | ||
127 | + i1 = line.find(multi_line_comments[0]) | ||
128 | + i2 = line.find(multi_line_comments[1]) | ||
129 | + | ||
130 | + if i1 < 0: | ||
131 | + i1 = len(line) + 1 | ||
132 | + if i2 < 0: | ||
133 | + i2 = len(line) + 1 | ||
134 | + | ||
135 | + if i1 < i2: | ||
136 | + strCh = multi_line_comments[0] | ||
137 | + else: | ||
138 | + strCh = multi_line_comments[1] | ||
139 | + | ||
140 | + result.append(line) | ||
141 | + if line.count(strCh) % 2 != 0: | ||
142 | + multi_string = True | ||
143 | + continue | ||
144 | + | ||
145 | + code = line.strip() | ||
146 | + | ||
147 | + if code[:3] in multi_line_comments: # detect in-out of multi-line comments | ||
148 | + if code.count(code[:3]) % 2 != 0: # odd number of delimiters: a multi-line comment starts or ends on this line | ||
149 | + multi_line = True | ||
150 | + strCh = code[:3] | ||
151 | + continue | ||
152 | + | ||
153 | + comment_index = line.find('#') | ||
154 | + if comment_index >= 0: # one line comment found | ||
155 | + line = line[:comment_index] | ||
156 | + line = line.rstrip() # remove rightmost spaces | ||
157 | + | ||
158 | + if len(line) == 0: # no code in this line | ||
159 | + continue | ||
160 | + | ||
161 | + result.append(line) # add to results | ||
162 | + | ||
163 | + return result | ||
164 | + | ||
165 | +def create_block_from_line(line): | ||
166 | + _line = remove_string(line) | ||
167 | + _line = _line.strip() | ||
168 | + | ||
169 | + if '@' in _line: | ||
170 | + return Block('TYPE_FACTORY', line) | ||
171 | + | ||
172 | + keywords = _line.split(' ') | ||
173 | + | ||
174 | + for key in keywords: | ||
175 | + if key in TYPE_IMPORT: | ||
176 | + return Block('TYPE_IMPORT', line) | ||
177 | + | ||
178 | + if key in TYPE_CLASS: | ||
179 | + return Block('TYPE_CLASS', line) | ||
180 | + | ||
181 | + if key in TYPE_DEF: | ||
182 | + return Block('TYPE_DEF', line) | ||
183 | + | ||
184 | + if key in TYPE_CONDITOIN: | ||
185 | + return Block('TYPE_CONDITION', line) | ||
186 | + | ||
187 | + return Block('TYPE_NORMAL', line) | ||
188 | + | ||
189 | +def create_dictionary(arr): # create index dictionary for str array | ||
190 | + ret = dict() | ||
191 | + | ||
192 | + key = 0 | ||
193 | + for name in arr: | ||
194 | + ret[name] = key | ||
195 | + key += 1 | ||
196 | + | ||
197 | + return ret | ||
198 | + | ||
199 | +def create_mapper(L): # create mapping array to match each index in range L | ||
200 | + arr = list(range(L)) | ||
201 | + random.shuffle(arr) | ||
202 | + ret = arr.copy() | ||
203 | + | ||
204 | + for i in range(L): | ||
205 | + ret[i] = arr[i] | ||
206 | + | ||
207 | + return ret | ||
208 | + | ||
209 | +def read_file(path): | ||
210 | + f = open(path, 'r', encoding='utf8') | ||
211 | + ret = f.readlines() | ||
212 | + f.close() | ||
213 | + return ret | ||
214 | + | ||
215 | +def write_file(path, lines): | ||
216 | + f = open(path, 'w', encoding='utf8') | ||
217 | + | ||
218 | + for line in lines: | ||
219 | + if '\n' in line: | ||
220 | + f.write(line) | ||
221 | + else: | ||
222 | + f.write(line + '\n') | ||
223 | + f.close() | ||
224 | + | ||
225 | +def write_block(path, block): | ||
226 | + f = open(path, 'w', encoding='utf8') | ||
227 | + f.write(str(block)) | ||
228 | + f.close() | ||
229 | + | ||
230 | +def shuffle_block(block): | ||
231 | + if block.blockType != 'TYPE_CLASS' and block.blockType != 'TYPE_ROOT': | ||
232 | + return | ||
233 | + | ||
234 | + for b in block.blocks: | ||
235 | + shuffle_block(b) | ||
236 | + | ||
237 | + random.shuffle(block.blocks) | ||
238 | + | ||
239 | +def detect_multi_string(line, stack): | ||
240 | + L = len(line) | ||
241 | + | ||
242 | + for i in range(L): | ||
243 | + if i + 3 > L: | ||
244 | + break | ||
245 | + | ||
246 | + s = line[i:i+3] | ||
247 | + if s in multi_line_comments: | ||
248 | + if len(stack) > 0 and stack[-1] == s: | ||
249 | + stack.pop() | ||
250 | + elif len(stack) == 0: | ||
251 | + stack.append(s) | ||
252 | + return len(stack) > 0 | ||
253 | + | ||
254 | +def detect_parenthesis(line, stack): | ||
255 | + line = remove_string(line) | ||
256 | + | ||
257 | + for c in line: | ||
258 | + if c == '(': | ||
259 | + stack.append(1) | ||
260 | + elif c == ')': | ||
261 | + stack.pop() | ||
262 | + | ||
263 | + if len(stack) > 0: | ||
264 | + print(line) | ||
265 | + return len(stack) > 0 | ||
266 | + | ||
267 | +def detect_multi_line_code(line): | ||
268 | + line = line.rstrip() | ||
269 | + return len(line) > 0 and line[-1] == '\\' | ||
270 | + | ||
271 | +def search_keyword(path, keyword, fast_detect=False): # if fast_detect is True, just check whether the keyword substring appears in the line; otherwise match whole tokens | ||
272 | + files = [f for f in readdir(path) if is_extension(f, 'py')] | ||
273 | + result = list() | ||
274 | + | ||
275 | + for p in files: | ||
276 | + lines = read_file(p) | ||
277 | + lines = remove_unnecessary_comments(lines) | ||
278 | + | ||
279 | + for line in lines: | ||
280 | + | ||
281 | + if fast_detect: | ||
282 | + if keyword in line: | ||
283 | + result.append(line) | ||
284 | + continue | ||
285 | + | ||
286 | + x = '' | ||
287 | + for c in line: | ||
288 | + if re.match('[A-Za-z_@0-9]', c): | ||
289 | + x += c | ||
290 | + else: | ||
291 | + x += ' ' | ||
292 | + | ||
293 | + keywords = x.split(' ') | ||
294 | + if keyword in keywords: | ||
295 | + result.append(line) | ||
296 | + | ||
297 | + return result | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
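A quick illustration of remove_unnecessary_comments from the utils.py above (a toy input, assumed to be run alongside dataset_generator/utils.py): single-line comments are stripped and whole multi-line comment blocks are dropped, subject to the limitations noted in the warning.

from utils import remove_unnecessary_comments

lines = ["x = 1  # set x", "'''", "a multi-line comment", "'''", "y = 2"]
print(remove_unnecessary_comments(lines))
# expected: ['x = 1', 'y = 2']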
code/siamese/config.py
0 → 100644
1 | +import os | ||
2 | + | ||
3 | +MAX_SEQ_LENGTH = 384 | ||
4 | +BATCH_SIZE = 64 | ||
5 | +EPOCHS = 50 | ||
6 | + | ||
7 | +BASE_OUTPUT = "output/siamese" | ||
8 | + | ||
9 | +DATASET_PATH = "data/pair_dataset.npz" #path for generated pair dataset | ||
10 | +VECTOR_PATH = "data/vectors.npz" #path for feature vectors from code dataset | ||
11 | +EMBEDDING_PATH = "data/embedding.npz" #path for embedding vector | ||
12 | +MODEL_PATH = os.path.sep.join([BASE_OUTPUT, "siamese_model"]) | ||
13 | +PLOT_PATH = os.path.sep.join([BASE_OUTPUT, "plot.png"]) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/siamese/dataset.py
0 → 100644
1 | +import numpy as np | ||
2 | +import random | ||
3 | +import pandas as pd | ||
4 | +from keras.preprocessing.text import Tokenizer | ||
5 | +from utils import * | ||
6 | + | ||
7 | +def save_dataset(path, pairData, pairLabels, compressed=True): | ||
8 | + if compressed: | ||
9 | + np.savez_compressed(path, pairData=pairData, pairLabels=pairLabels) | ||
10 | + else: | ||
11 | + np.savez(path, pairData=pairData, pairLabels=pairLabels) | ||
12 | + | ||
13 | +def load_dataset(path): | ||
14 | + data = np.load(path, allow_pickle=True) | ||
15 | + return (data['pairData'], data['pairLabels']) | ||
16 | + | ||
17 | +def make_dataset_small(path): # couldn't make the dataset for shuffled/merged/obfuscated, as memory ran out. | ||
18 | + vecs = np.load(path, allow_pickle=True)['vecs'] | ||
19 | + | ||
20 | + pairData = [] | ||
21 | + pairLabels = [] # 1 for plagiarism | ||
22 | + | ||
23 | + # original pair | ||
24 | + for i in range(len(vecs)): | ||
25 | + currentData = vecs[i] | ||
26 | + | ||
27 | + pairData.append([currentData, currentData]) | ||
28 | + pairLabels.append([1]) | ||
29 | + | ||
30 | + j = i | ||
31 | + while j == i: | ||
32 | + j = random.randint(0, len(vecs) - 1) | ||
33 | + | ||
34 | + pairData.append([currentData, vecs[j]]) | ||
35 | + pairLabels.append([0]) | ||
36 | + | ||
37 | + return (np.array(pairData), np.array(pairLabels)) | ||
38 | + | ||
39 | +def load_embedding(path): | ||
40 | + data = np.load(path, allow_pickle=True) | ||
41 | + return (data['vocab_size'], data['embedding_matrix']) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
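A possible invocation of the pair builder above (paths taken from config.py; this is a sketch assuming data/vectors.npz holds an array named 'vecs' of fixed-length code vectors). Each vector yields one positive pair (the vector with itself, label 1) and one negative pair (the vector with a randomly chosen different vector, label 0).

from dataset import make_dataset_small, save_dataset, load_dataset

pairData, pairLabels = make_dataset_small('data/vectors.npz')   # build (anchor, other) pairs
save_dataset('data/pair_dataset.npz', pairData, pairLabels)     # compressed .npz by default
pairData, pairLabels = load_dataset('data/pair_dataset.npz')    # round-trip check
print(pairData.shape, pairLabels.shape)                         # (2N, 2, D) and (2N, 1) for fixed-length vectors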
code/siamese/file_parser.py
0 → 100644
1 | +import re | ||
2 | +from utils import remove_string | ||
3 | + | ||
4 | +def parse_keywords(line): | ||
5 | + line = line.strip() | ||
6 | + line = remove_string(line) | ||
7 | + result = '' | ||
8 | + | ||
9 | + for c in line: | ||
10 | + if re.match('[A-Za-z_@0-9]', c): | ||
11 | + result += c | ||
12 | + else: | ||
13 | + result += ' ' | ||
14 | + | ||
15 | + return result.split(' ') | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/siamese/model.py
0 → 100644
1 | +from tensorflow.python.keras import backend as K | ||
2 | +from tensorflow.keras.models import Model | ||
3 | +from tensorflow.keras.layers import Input | ||
4 | +from tensorflow.keras.layers import Layer | ||
5 | +from tensorflow.keras.layers import LSTM | ||
6 | +from tensorflow.keras.layers import Embedding | ||
7 | +from tensorflow.python.keras.layers.wrappers import Bidirectional | ||
8 | +from tensorflow.keras.models import Sequential | ||
9 | +from tensorflow.keras.optimizers import Adam | ||
10 | + | ||
11 | +class ManDist(Layer): | ||
12 | + def __init__(self, **kwargs): | ||
13 | + self.result = None | ||
14 | + super(ManDist, self).__init__(**kwargs) | ||
15 | + | ||
16 | + def build(self, input_shape): | ||
17 | + super(ManDist, self).build(input_shape) | ||
18 | + | ||
19 | + def call(self, x, **kwargs): | ||
20 | + self.result = K.exp(-K.sum(K.abs(x[0] - x[1]), axis=1, keepdims=True)) | ||
21 | + return self.result | ||
22 | + | ||
23 | + def compute_output_shape(self, input_shape): | ||
24 | + return K.int_shape(self.result) | ||
25 | + | ||
26 | +def build_siamese_model(embedding_matrix, embeddingDim, max_sequence_length=384, number_lstm_units=50, rate_drop_lstm=0.01): | ||
27 | + | ||
28 | + x = Sequential() | ||
29 | + x.add(Embedding(len(embedding_matrix), embeddingDim, weights=[embedding_matrix], input_shape=(max_sequence_length,), trainable=False)) | ||
30 | + x.add(LSTM(number_lstm_units, dropout=rate_drop_lstm, return_sequences=True, activation='softmax')) | ||
31 | + | ||
32 | + input_1 = Input(shape=(max_sequence_length,), dtype='int32') | ||
33 | + input_2 = Input(shape=(max_sequence_length,), dtype='int32') | ||
34 | + | ||
35 | + distance = ManDist()([x(input_1), x(input_2)]) | ||
36 | + model = Model(inputs=[input_1, input_2], outputs=[distance]) | ||
37 | + model.compile(loss='mean_squared_error', optimizer=Adam(learning_rate=0.001), metrics=['accuracy']) | ||
38 | + | ||
39 | + return model | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
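ManDist computes the Manhattan-distance similarity used in MaLSTM-style siamese networks, exp(-||h1 - h2||_1): it is 1.0 for identical encodings and decays toward 0 as the encodings diverge. A small numpy sketch with made-up activations (the values below are arbitrary, only for illustration):

import numpy as np

h1 = np.array([[0.2, 0.4, 0.1]])   # toy encoding of input_1
h2 = np.array([[0.1, 0.4, 0.3]])   # toy encoding of input_2
sim = np.exp(-np.sum(np.abs(h1 - h2), axis=1, keepdims=True))
print(sim)  # ~[[0.7408]]; identical encodings would give [[1.0]]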
code/siamese/predict.py
0 → 100644
1 | +import config | ||
2 | +from tensorflow.keras.models import load_model | ||
3 | +from gensim.models import KeyedVectors | ||
4 | +from file_parser import parse_keywords | ||
5 | +import tensorflow as tf | ||
6 | +from utils import * | ||
7 | +import random | ||
8 | +import numpy as np | ||
9 | + | ||
10 | +def avg_feature_vector(text, model, num_features, index2word_set): | ||
11 | + words = parse_keywords(text) | ||
12 | + feature_vec = np.zeros((num_features,), dtype='float32') | ||
13 | + n_words = 0 | ||
14 | + for word in words: | ||
15 | + if word in index2word_set: | ||
16 | + n_words += 1 | ||
17 | + feature_vec = np.add(feature_vec, model[word]) | ||
18 | + if (n_words > 0): | ||
19 | + feature_vec = np.divide(feature_vec, n_words) | ||
20 | + return feature_vec | ||
21 | + | ||
22 | +def compare(c2v_model, model, dir1, dir2): | ||
23 | + files = [f for f in readdir(dir1) if is_extension(f, 'py')] | ||
24 | + | ||
25 | + plt.ylabel('cos_sim') | ||
26 | + m = 10 | ||
27 | + Mx = 0 | ||
28 | + idx = 0 | ||
29 | + L = len(files) | ||
30 | + data = [] | ||
31 | + index2word_set = set(c2v_model.index_to_key) | ||
32 | + | ||
33 | + for f in files: | ||
34 | + print(idx,"/",L) | ||
35 | + f2 = dir2 + f.split(dir1)[1] | ||
36 | + | ||
37 | + text1 = readAll(f) | ||
38 | + text2 = readAll(f2) | ||
39 | + | ||
40 | + input1 = avg_feature_vector(text1, c2v_model, 384, index2word_set) | ||
41 | + input2 = avg_feature_vector(text2, c2v_model, 384, index2word_set) | ||
42 | + | ||
43 | + data.append([[input1], [input2]]) | ||
44 | + idx += 1 | ||
45 | + | ||
46 | + result = model.predict(data) | ||
47 | + print(result) | ||
48 | + | ||
49 | +vectors_text_path = 'data/targets.txt' | ||
50 | +c2v_model = KeyedVectors.load_word2vec_format(vectors_text_path, binary=False) | ||
51 | +model = load_model(config.MODEL_PATH) | ||
52 | + | ||
53 | +# Usage | ||
54 | +# compare(c2v_model, model, 'data/refined', 'data/shuffled') | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/siamese/test.py
0 → 100644
1 | +import config | ||
2 | +from dataset import load_dataset | ||
3 | +from tensorflow.keras.models import load_model | ||
4 | +import tensorflow as tf | ||
5 | + | ||
6 | +pairData, pairLabels = load_dataset(config.DATASET_PATH) | ||
7 | +print("Loaded Dataset") | ||
8 | + | ||
9 | +X1 = pairData[:, 0].tolist() | ||
10 | +X2 = pairData[:, 1].tolist() | ||
11 | +Label = pairLabels[:].tolist() | ||
12 | + | ||
13 | +X1 = tf.convert_to_tensor(X1) | ||
14 | +X2 = tf.convert_to_tensor(X2) | ||
15 | +Label = tf.convert_to_tensor(Label) | ||
16 | + | ||
17 | +model = load_model(config.MODEL_PATH) | ||
18 | + | ||
19 | +result = model.evaluate([X1, X2], Label, batch_size=64) | ||
20 | +print("test loss, test acc:", result) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/siamese/train.py
0 → 100644
2 | +from utils import plot_training | ||
3 | +import config | ||
4 | +import os | ||
5 | +import numpy as np | ||
6 | +import random | ||
7 | +import tensorflow as tf | ||
8 | +from dataset import load_dataset, load_embedding, make_dataset_small, save_dataset | ||
9 | +from model import build_siamese_model | ||
10 | +from tensorflow.keras.models import load_model | ||
11 | +from tensorflow.keras.callbacks import Callback | ||
12 | + | ||
13 | +# load dataset | ||
14 | +if os.path.exists(config.DATASET_PATH): | ||
15 | + pairData, pairLabels = load_dataset(config.DATASET_PATH) | ||
16 | + print("Loaded Dataset") | ||
17 | +else: | ||
18 | + print("Generating Dataset...") | ||
19 | + pairData, pairLabels = make_dataset_small(config.VECTOR_PATH) | ||
20 | + save_dataset(config.DATASET_PATH, pairData, pairLabels) | ||
21 | + print("Saved Dataset") | ||
22 | + | ||
23 | +# build model | ||
24 | + | ||
25 | +if not os.path.exists(config.MODEL_PATH): | ||
26 | + print("Loading Embedding Vectors...") | ||
27 | + vocab_size, embedding_matrix = load_embedding(config.EMBEDDING_PATH) | ||
28 | + print("Building Models...") | ||
29 | + model = build_siamese_model(embedding_matrix, 384) | ||
30 | +else: | ||
31 | + model = load_model(config.MODEL_PATH) | ||
32 | + | ||
33 | +# train model | ||
34 | + | ||
35 | +X1 = pairData[:, 0].tolist() | ||
36 | +X2 = pairData[:, 1].tolist() | ||
37 | +Label = pairLabels[:].tolist() | ||
38 | + | ||
39 | +X1 = tf.convert_to_tensor(X1) | ||
40 | +X2 = tf.convert_to_tensor(X2) | ||
41 | +Label = tf.convert_to_tensor(Label) | ||
42 | + | ||
43 | +Length = int(len(X1) * 0.7) # 70% train / 30% validation | ||
44 | +trainX1, testX1 = X1[:Length], X1[Length:] # use the remaining 30% so train and validation do not overlap | ||
45 | +trainX2, testX2 = X2[:Length], X2[Length:] | ||
46 | +trainY, testY = Label[:Length], Label[Length:] | ||
47 | + | ||
48 | +print("Training Model...") | ||
49 | + | ||
50 | +history = model.fit([trainX1, trainX2], trainY, batch_size=config.BATCH_SIZE, epochs=config.EPOCHS, | ||
51 | + validation_data=([testX1, testX2], testY)) | ||
52 | + | ||
53 | + | ||
54 | +print("Saving Model...") | ||
55 | +model.save(config.MODEL_PATH) | ||
56 | +print("Saved Model") | ||
57 | + | ||
58 | +plot_training(history, config.PLOT_PATH) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/siamese/utils.py
0 → 100644
1 | +import os | ||
2 | +import re | ||
3 | +import matplotlib.pyplot as plt | ||
4 | + | ||
5 | +multi_line_comments = ["'''", '"""'] | ||
6 | + | ||
7 | +def remove_string(line): | ||
8 | + strIn = False | ||
9 | + strCh = None | ||
10 | + result = '' | ||
11 | + i = 0 | ||
12 | + L = len(line) | ||
13 | + | ||
14 | + while i < L: | ||
15 | + if i + 3 < L: | ||
16 | + if line[i:i+3] in multi_line_comments: | ||
17 | + if not strIn: | ||
18 | + strIn = True | ||
19 | + strCh = line[i:i+3] | ||
20 | + elif line[i:i+3] == strCh: | ||
21 | + strIn = False | ||
22 | + | ||
23 | + i += 2 | ||
24 | + continue | ||
25 | + | ||
26 | + c = line[i] | ||
27 | + i += 1 | ||
28 | + | ||
29 | + if c == '\'' or c == '\"': | ||
30 | + if not strIn: | ||
31 | + strIn = True | ||
32 | + strCh = c | ||
33 | + elif c == strCh: | ||
34 | + strIn = False | ||
35 | + continue | ||
36 | + | ||
37 | + if strIn: | ||
38 | + continue | ||
39 | + | ||
40 | + result += c | ||
41 | + | ||
42 | + return result | ||
43 | + | ||
44 | +def is_extension(f, ext): | ||
45 | + return os.path.splitext(f)[1][1:] == ext | ||
46 | + | ||
47 | +def _readdir_r(dirpath): # readdir for recursive | ||
48 | + ret = [] | ||
49 | + for f in os.listdir(dirpath): | ||
50 | + ret.append(os.path.join(dirpath, f)) | ||
51 | + | ||
52 | + return ret | ||
53 | + | ||
54 | +def readdir(path): # read files from the directory | ||
55 | + pathList = [path] | ||
56 | + result = [] | ||
57 | + i = 0 | ||
58 | + | ||
59 | + while i < len(pathList): | ||
60 | + f = pathList[i] | ||
61 | + if os.path.isdir(f): | ||
62 | + pathList += _readdir_r(f) | ||
63 | + else: | ||
64 | + result.append(f) | ||
65 | + | ||
66 | + i += 1 | ||
67 | + | ||
68 | + return result | ||
69 | + | ||
70 | +def readAll(path): | ||
71 | + f = open(path, 'r', encoding='utf8') | ||
72 | + ret = f.read() | ||
73 | + f.close() | ||
74 | + return ret | ||
75 | + | ||
76 | +def readLines(path): | ||
77 | + f = open(path, 'r', encoding='utf8') | ||
78 | + ret = f.readlines() | ||
79 | + f.close() | ||
80 | + return ret | ||
81 | + | ||
82 | +def plot_training(H, plotPath): | ||
83 | + plt.style.use("ggplot") | ||
84 | + plt.figure() | ||
85 | + plt.plot(H.history["loss"], label="train_loss") | ||
86 | + plt.plot(H.history["val_loss"], label="val_loss") | ||
87 | + plt.plot(H.history["accuracy"], label="train_acc") | ||
88 | + plt.plot(H.history["val_accuracy"], label="val_acc") | ||
89 | + plt.title("Training Loss and Accuracy") | ||
90 | + plt.xlabel("Epoch #") | ||
91 | + plt.ylabel("Loss/Accuracy") | ||
92 | + plt.legend(loc="lower left") | ||
93 | + plt.savefig(plotPath) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/similarity_plotter/code2vec_tester.py
0 → 100644
1 | +from gensim.models import KeyedVectors | ||
2 | +import text2vec | ||
3 | +import random | ||
4 | +from utils import * | ||
5 | +import matplotlib.pyplot as plt | ||
6 | + | ||
7 | +vectors_text_path = 'data/targets.txt' # w2v output file from model | ||
8 | +model = KeyedVectors.load_word2vec_format(vectors_text_path, binary=False) | ||
9 | + | ||
10 | +def compare(dir1, dir2): | ||
11 | + files = [f for f in readdir(dir1) if is_extension(f, 'py')] | ||
12 | + | ||
13 | + plt.ylabel('cos_sim') | ||
14 | + m = 10 | ||
15 | + Mx = 0 | ||
16 | + idx = 0 | ||
17 | + L = len(files) | ||
18 | + | ||
19 | + for f in files: | ||
20 | + print(idx,"/",L) | ||
21 | + f2 = dir2 + f.split(dir1)[1] | ||
22 | + | ||
23 | + text1 = readAll(f) | ||
24 | + text2 = readAll(f2) | ||
25 | + | ||
26 | + similarity = text2vec.get_similarity(text1, text2, model, 384) | ||
27 | + m = min(m, similarity) | ||
28 | + Mx = max(Mx, similarity) | ||
29 | + plt.plot(idx, similarity, 'r.') | ||
30 | + idx += 1 | ||
31 | + | ||
32 | + print("min:", m, "max:", Mx) | ||
33 | + plt.show() | ||
34 | + | ||
35 | +def compare2(path): # for merged dataset | ||
36 | + pairs = read_file(path + '/log.txt') # log file format: path_merged path_source1 path_source2 | ||
37 | + | ||
38 | + plt.ylabel('cos_sim') | ||
39 | + m = 10 | ||
40 | + Mx = 0 | ||
41 | + idx = 0 | ||
42 | + L = len(pairs) | ||
43 | + s1 = [] | ||
44 | + s2 = [] | ||
45 | + | ||
46 | + for p in pairs: | ||
47 | + print(idx,"/",L) | ||
48 | + arr = p.split(' ') | ||
49 | + C = path + '/' + arr[0].strip() | ||
50 | + A = arr[1].strip() | ||
51 | + B = arr[2].strip() | ||
52 | + | ||
53 | + text_A = readAll(A) | ||
54 | + text_B = readAll(B) | ||
55 | + text_C = readAll(C) | ||
56 | + | ||
57 | + similarity = text2vec.get_similarity(text_A, text_C, model, 384) | ||
58 | + m = min(m, similarity) | ||
59 | + Mx = max(Mx, similarity) | ||
60 | + s1.append(similarity) | ||
61 | + | ||
62 | + similarity = text2vec.get_similarity(text_B, text_C, model, 384) | ||
63 | + m = min(m, similarity) | ||
64 | + Mx = max(Mx, similarity) | ||
65 | + s2.append(similarity) | ||
66 | + idx += 1 | ||
67 | + | ||
68 | + print("min:", m, "max:", Mx) | ||
69 | + plt.plot(s1, 'r.') | ||
70 | + plt.waitforbuttonpress() | ||
71 | + | ||
72 | + plt.cla() | ||
73 | + plt.plot(s2, 'b.') | ||
74 | + plt.show() | ||
75 | + | ||
76 | +def compare3(dir): # for comparison within the original dataset (n^2 here; beware of long processing time) | ||
77 | + files = [f for f in readdir(dir) if is_extension(f, 'py')] | ||
78 | + | ||
79 | + plt.ylabel('cos_sim') | ||
80 | + m = 10 | ||
81 | + Mx = 0 | ||
82 | + idx = 0 | ||
83 | + L = len(files) | ||
84 | + data = [] | ||
85 | + | ||
86 | + for f in files: | ||
87 | + print(idx,"/",L) | ||
88 | + | ||
89 | + text = readAll(f) | ||
90 | + data.append(text) | ||
91 | + idx += 1 | ||
92 | + | ||
93 | + for i in range(L): | ||
94 | + print(i) | ||
95 | + j = i | ||
96 | + if i == 0: | ||
97 | + continue | ||
98 | + while j == i: | ||
99 | + j = random.choice(list(range(i))) | ||
100 | + | ||
101 | + similarity = text2vec.get_similarity(data[i], data[j], model, 384) | ||
102 | + m = min(m, similarity) | ||
103 | + Mx = max(Mx, similarity) | ||
104 | + plt.plot(i, similarity, 'r.') | ||
105 | + | ||
106 | + print("min:", m, "max:", Mx) | ||
107 | + plt.show() | ||
108 | + | ||
109 | +# Usage | ||
110 | +# compare('data/refined', 'data/obfuscated2') | ||
111 | +# compare2('data/merged') | ||
112 | +# compare3('data/refined') | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/similarity_plotter/file_parser.py
0 → 100644
1 | +import re | ||
2 | +from utils import remove_string | ||
3 | + | ||
4 | +def parse_keywords(line): | ||
5 | + line = line.strip() | ||
6 | + line = remove_string(line) | ||
7 | + result = '' | ||
8 | + | ||
9 | + for c in line: | ||
10 | + if re.match('[A-Za-z_@0-9]', c): | ||
11 | + result += c | ||
12 | + else: | ||
13 | + result += ' ' | ||
14 | + | ||
15 | + return result.split(' ') | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/similarity_plotter/text2vec.py
0 → 100644
1 | +from file_parser import parse_keywords | ||
2 | +import numpy as np | ||
3 | +from scipy import spatial | ||
4 | + | ||
5 | +def avg_feature_vector(text, model, num_features, index2word_set): | ||
6 | + words = parse_keywords(text) | ||
7 | + feature_vec = np.zeros((num_features, ), dtype='float32') | ||
8 | + n_words = 0 | ||
9 | + for word in words: | ||
10 | + if word in index2word_set: | ||
11 | + n_words += 1 | ||
12 | + feature_vec = np.add(feature_vec, model[word]) | ||
13 | + if (n_words > 0): | ||
14 | + feature_vec = np.divide(feature_vec, n_words) | ||
15 | + return feature_vec | ||
16 | + | ||
17 | +def get_similarity(text1, text2, model, num_features): | ||
18 | + index2word_set = set(model.index_to_key) | ||
19 | + s1 = avg_feature_vector(text1, model, num_features, index2word_set) | ||
20 | + s2 = avg_feature_vector(text2, model, num_features, index2word_set) | ||
21 | + return abs(1 - spatial.distance.cosine(s1, s2)) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
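get_similarity averages the word vectors of each text and returns the cosine similarity of the two averages; scipy's cosine() returns a distance, so 1 minus the distance is the similarity. A toy check with hand-picked 3-dimensional vectors standing in for the two averaged feature vectors:

import numpy as np
from scipy import spatial

s1 = np.array([1.0, 0.0, 1.0])
s2 = np.array([1.0, 1.0, 0.0])
print(abs(1 - spatial.distance.cosine(s1, s2)))  # 0.5, i.e. cos(60 degrees) between the two averages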
code/similarity_plotter/utils.py
0 → 100644
1 | +import os | ||
2 | + | ||
3 | +multi_line_comments = ["'''", '"""'] | ||
4 | + | ||
5 | +def remove_string(line): | ||
6 | + strIn = False | ||
7 | + strCh = None | ||
8 | + result = '' | ||
9 | + i = 0 | ||
10 | + L = len(line) | ||
11 | + | ||
12 | + while i < L: | ||
13 | + if i + 3 < L: | ||
14 | + if line[i:i+3] in multi_line_comments: | ||
15 | + if not strIn: | ||
16 | + strIn = True | ||
17 | + strCh = line[i:i+3] | ||
18 | + elif line[i:i+3] == strCh: | ||
19 | + strIn = False | ||
20 | + | ||
21 | + i += 2 | ||
22 | + continue | ||
23 | + | ||
24 | + c = line[i] | ||
25 | + i += 1 | ||
26 | + | ||
27 | + if c == '\'' or c == '\"': | ||
28 | + if not strIn: | ||
29 | + strIn = True | ||
30 | + strCh = c | ||
31 | + elif c == strCh: | ||
32 | + strIn = False | ||
33 | + continue | ||
34 | + | ||
35 | + if strIn: | ||
36 | + continue | ||
37 | + | ||
38 | + result += c | ||
39 | + | ||
40 | + return result | ||
41 | + | ||
42 | +def using_multi_string(line, index): | ||
43 | + line = line.strip() | ||
44 | + for comment in multi_line_comments: | ||
45 | + if line.find(comment, index) > 0: | ||
46 | + return True | ||
47 | + return False | ||
48 | + | ||
49 | +def remove_unnecessary_comments(lines): | ||
50 | + # Warning: cannot reliably detect all multi-line comments, because a multi-line comment is syntactically just a multi-line string. | ||
51 | + | ||
52 | + # TODO: the multi-line string parser will not work well when a line uses more than one string (or comment), e.g.: | ||
53 | + # ex) a = ''' d ''' + ''' | ||
54 | + # abc ''' + ''' | ||
55 | + # x''' | ||
56 | + | ||
57 | + result = [] | ||
58 | + multi_line = False | ||
59 | + multi_string = False | ||
60 | + strCh = None | ||
61 | + | ||
62 | + for line in lines: | ||
63 | + find_str_index = 0 | ||
64 | + if multi_string: | ||
65 | + if strCh in line: | ||
66 | + find_str_index = line.find(strCh) + 3 | ||
67 | + multi_string = False | ||
68 | + strCh = None | ||
69 | + | ||
70 | + result.append(line) | ||
71 | + continue | ||
72 | + | ||
73 | + if multi_line: # parsing multi-line comments | ||
74 | + if strCh in line: | ||
75 | + multi_line = False | ||
76 | + strCh = None | ||
77 | + continue | ||
78 | + | ||
79 | + if using_multi_string(line, find_str_index): | ||
80 | + i1 = line.find(multi_line_comments[0]) | ||
81 | + i2 = line.find(multi_line_comments[1]) | ||
82 | + | ||
83 | + if i1 < 0: | ||
84 | + i1 = len(line) + 1 | ||
85 | + if i2 < 0: | ||
86 | + i2 = len(line) + 1 | ||
87 | + | ||
88 | + if i1 < i2: | ||
89 | + strCh = multi_line_comments[0] | ||
90 | + else: | ||
91 | + strCh = multi_line_comments[1] | ||
92 | + | ||
93 | + result.append(line) | ||
94 | + if line.count(strCh) % 2 != 0: | ||
95 | + multi_string = True | ||
96 | + continue | ||
97 | + | ||
98 | + code = line.strip() | ||
99 | + | ||
100 | + if code[:3] in multi_line_comments: # detect in-out of multi-line comments | ||
101 | + if code.count(code[:3]) % 2 != 0: # odd number of delimiters: a multi-line comment starts or ends on this line | ||
102 | + multi_line = True | ||
103 | + strCh = code[:3] | ||
104 | + continue | ||
105 | + | ||
106 | + comment_index = line.find('#') | ||
107 | + if comment_index >= 0: # one line comment found | ||
108 | + line = line[:comment_index] | ||
109 | + line = line.rstrip() # remove rightmost spaces | ||
110 | + | ||
111 | + if len(line) == 0: # no code in this line | ||
112 | + continue | ||
113 | + | ||
114 | + result.append(line) # add to results | ||
115 | + | ||
116 | + return result | ||
117 | + | ||
118 | +def is_extension(f, ext): | ||
119 | + return os.path.splitext(f)[1][1:] == ext | ||
120 | + | ||
121 | +def _readdir_r(dirpath): # readdir for recursive | ||
122 | + ret = [] | ||
123 | + for f in os.listdir(dirpath): | ||
124 | + ret.append(os.path.join(dirpath, f)) | ||
125 | + | ||
126 | + return ret | ||
127 | + | ||
128 | +def readdir(path): # read files from the directory | ||
129 | + pathList = [path] | ||
130 | + result = [] | ||
131 | + i = 0 | ||
132 | + | ||
133 | + while i < len(pathList): | ||
134 | + f = pathList[i] | ||
135 | + if os.path.isdir(f): | ||
136 | + pathList += _readdir_r(f) | ||
137 | + else: | ||
138 | + result.append(f) | ||
139 | + | ||
140 | + i += 1 | ||
141 | + | ||
142 | + return result | ||
143 | + | ||
144 | +def read_file(path): | ||
145 | + f = open(path, 'r', encoding='utf8') | ||
146 | + ret = f.readlines() | ||
147 | + f.close() | ||
148 | + return ret | ||
149 | + | ||
150 | +def write_file(path, lines): | ||
151 | + f = open(path, 'w', encoding='utf8') | ||
152 | + | ||
153 | + for line in lines: | ||
154 | + if '\n' in line: | ||
155 | + f.write(line) | ||
156 | + else: | ||
157 | + f.write(line + '\n') | ||
158 | + f.close() | ||
159 | + | ||
160 | +def readAll(path): | ||
161 | + f = open(path, 'r', encoding='utf8') | ||
162 | + ret = f.read() | ||
163 | + f.close() | ||
164 | + return ret | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
reports/최종보고서.pdf
0 → 100644
No preview for this file type