Showing
1 changed file
with
449 additions
and
439 deletions
1 | { | 1 | { |
2 | - "nbformat": 4, | 2 | + "nbformat": 4, |
3 | - "nbformat_minor": 0, | 3 | + "nbformat_minor": 0, |
4 | - "metadata": { | 4 | + "metadata": { |
5 | - "colab": { | 5 | + "colab": { |
6 | - "name": "commit-autosuggestions.ipynb", | 6 | + "name": "commit-autosuggestions.ipynb", |
7 | - "provenance": [], | 7 | + "provenance": [], |
8 | - "collapsed_sections": [], | 8 | + "collapsed_sections": [], |
9 | - "toc_visible": true | 9 | + "toc_visible": true |
10 | - }, | ||
11 | - "kernelspec": { | ||
12 | - "name": "python3", | ||
13 | - "display_name": "Python 3" | ||
14 | - }, | ||
15 | - "accelerator": "GPU" | ||
16 | }, | 10 | }, |
17 | - "cells": [ | 11 | + "kernelspec": { |
18 | - { | 12 | + "name": "python3", |
19 | - "cell_type": "markdown", | 13 | + "display_name": "Python 3" |
20 | - "metadata": { | 14 | + }, |
21 | - "id": "DZ7rFp2gzuNS" | 15 | + "accelerator": "GPU" |
22 | - }, | 16 | + }, |
23 | - "source": [ | 17 | + "cells": [ |
24 | - "## Start commit-autosuggestions server\n", | 18 | + { |
25 | - "Running flask app server in Google Colab for people without GPU" | 19 | + "cell_type": "markdown", |
26 | - ] | 20 | + "metadata": { |
27 | - }, | 21 | + "id": "DZ7rFp2gzuNS" |
28 | - { | 22 | + }, |
29 | - "cell_type": "markdown", | 23 | + "source": [ |
30 | - "metadata": { | 24 | + "## Start commit-autosuggestions server\n", |
31 | - "id": "d8Lyin2I3wHq" | 25 | + "Running flask app server in Google Colab for people without GPU" |
32 | - }, | 26 | + ] |
33 | - "source": [ | 27 | + }, |
34 | - "#### Clone github repository" | 28 | + { |
35 | - ] | 29 | + "cell_type": "markdown", |
36 | - }, | 30 | + "metadata": { |
37 | - { | 31 | + "id": "d8Lyin2I3wHq" |
38 | - "cell_type": "code", | 32 | + }, |
39 | - "metadata": { | 33 | + "source": [ |
40 | - "id": "e_cu9igvzjcs" | 34 | + "#### Clone github repository" |
41 | - }, | 35 | + ] |
42 | - "source": [ | 36 | + }, |
43 | - "!git clone https://github.com/graykode/commit-autosuggestions.git\n", | 37 | + { |
44 | - "%cd commit-autosuggestions\n", | 38 | + "cell_type": "code", |
45 | - "!pip install -r requirements.txt" | 39 | + "metadata": { |
46 | - ], | 40 | + "id": "e_cu9igvzjcs" |
47 | - "execution_count": null, | 41 | + }, |
48 | - "outputs": [] | 42 | + "source": [ |
49 | - }, | 43 | + "!git clone https://github.com/graykode/commit-autosuggestions.git\n", |
50 | - { | 44 | + "%cd commit-autosuggestions\n", |
51 | - "cell_type": "markdown", | 45 | + "!pip install -r requirements.txt" |
52 | - "metadata": { | 46 | + ], |
53 | - "id": "PFKn5QZr0dQx" | 47 | + "execution_count": null, |
54 | - }, | 48 | + "outputs": [] |
55 | - "source": [ | 49 | + }, |
56 | - "#### Download model weights\n", | 50 | + { |
57 | - "\n", | 51 | + "cell_type": "markdown", |
58 | - "Download the two model weights from Google Drive using the gdown module.\n", | 52 | + "metadata": { |
59 | - "1. [Added model](https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4) : A Code2NL model trained on Python using pre-trained CodeBERT (Feng et al., 2020).\n", | 53 | + "id": "PFKn5QZr0dQx" |
60 | - "2. [Diff model](https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m) : A model initialized with the weights of model (1) and retrained with an additional embedding (`patch_ids_embedding`) for the added and deleted parts of the code." | 54 | + }, |
61 | - ] | 55 | + "source": [ |
62 | - }, | 56 | + "#### Download model weights\n", |
63 | - { | 57 | + "\n", |
64 | - "cell_type": "code", | 58 | + "Download the two model weights from Google Drive using the gdown module.\n", |
65 | - "metadata": { | 59 | + "1. Added model : A Code2NL model trained on Python using pre-trained CodeBERT (Feng et al., 2020).\n", |
66 | - "id": "P9-EBpxt0Dp0" | 60 | + "2. Diff model : A model initialized with the weights of model (1) and retrained with an additional embedding (`patch_ids_embedding`) for the added and deleted parts of the code.\n", |
67 | - }, | 61 | + "\n", |
68 | - "source": [ | 62 | + "Pre-trained weights (Google Drive file IDs; set `ADD_MODEL` and `DIFF_MODEL` in the next cell to the pair for your language)\n", |
69 | - "!pip install gdown \\\n", | 63 | + "\n", |
70 | - " && gdown \"https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4\" -O weight/added/pytorch_model.bin \\\n", | 64 | + "Language | Added | Diff\n", |
71 | - " && gdown \"https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\" -O weight/diff/pytorch_model.bin" | 65 | + "--- | --- | ---\n", |
72 | - ], | 66 | + "python | 1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4 | 1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\n", |
73 | - "execution_count": null, | 67 | + "javascript | 1-F68ymKxZ-htCzQ8_Y9iHexs2SJmP5Gc | 1-39rmu-3clwebNURMQGMt-oM4HsAkbsf" |
74 | - "outputs": [] | 68 | + ] |
75 | - }, | 69 | + }, |
76 | - { | 70 | + { |
77 | - "cell_type": "markdown", | 71 | + "cell_type": "code", |
78 | - "metadata": { | 72 | + "metadata": { |
79 | - "id": "org4Gqdv3iUu" | 73 | + "id": "P9-EBpxt0Dp0" |
80 | - }, | 74 | + }, |
81 | - "source": [ | 75 | + "source": [ |
82 | - "#### ngrok setting with flask\n", | 76 | + "ADD_MODEL='1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4'\n", |
83 | - "\n", | 77 | + "DIFF_MODEL='1--gcVVix92_Fp75A-mWH0pJS0ahlni5m'\n", |
84 | - "Before starting the server, you need to configure ngrok to expose this notebook to the outside. For details, I referred to [this Jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb)." | 78 | + "\n", |
85 | - ] | 79 | + "!pip install gdown \\\n", |
86 | - }, | 80 | + " && gdown \"https://drive.google.com/uc?id=$ADD_MODEL\" -O weight/added/pytorch_model.bin \\\n", |
87 | - { | 81 | + " && gdown \"https://drive.google.com/uc?id=$DIFF_MODEL\" -O weight/diff/pytorch_model.bin" |
88 | - "cell_type": "code", | 82 | + ], |
89 | - "metadata": { | 83 | + "execution_count": null, |
90 | - "id": "lZA3kuuG1Crj" | 84 | + "outputs": [] |
91 | - }, | 85 | + }, |
92 | - "source": [ | 86 | + { |
93 | - "!pip install flask-ngrok" | 87 | + "cell_type": "markdown", |
94 | - ], | 88 | + "metadata": { |
95 | - "execution_count": null, | 89 | + "id": "org4Gqdv3iUu" |
96 | - "outputs": [] | 90 | + }, |
97 | - }, | 91 | + "source": [ |
98 | - { | 92 | + "#### ngrok setting with flask\n", |
99 | - "cell_type": "markdown", | 93 | + "\n", |
100 | - "metadata": { | 94 | + "Before starting the server, you need to configure ngrok to expose this notebook to the outside. For details, I referred to [this Jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb)." |
101 | - "id": "hR78FRCMcqrZ" | 95 | + ] |
102 | - }, | 96 | + }, |
103 | - "source": [ | 97 | + { |
104 | - "Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n", | 98 | + "cell_type": "code", |
105 | - "\n" | 99 | + "metadata": { |
106 | - ] | 100 | + "id": "lZA3kuuG1Crj" |
107 | - }, | 101 | + }, |
108 | - { | 102 | + "source": [ |
109 | - "cell_type": "code", | 103 | + "!pip install flask-ngrok" |
110 | - "metadata": { | 104 | + ], |
111 | - "id": "L_mInbOKcoc2" | 105 | + "execution_count": null, |
112 | - }, | 106 | + "outputs": [] |
113 | - "source": [ | 107 | + }, |
114 | - "# Paste your authtoken here in quotes\n", | 108 | + { |
115 | - "authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\"" | 109 | + "cell_type": "markdown", |
116 | - ], | 110 | + "metadata": { |
117 | - "execution_count": null, | 111 | + "id": "hR78FRCMcqrZ" |
118 | - "outputs": [] | 112 | + }, |
119 | - }, | 113 | + "source": [ |
120 | - { | 114 | + "Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n", |
121 | - "cell_type": "markdown", | 115 | + "\n" |
122 | - "metadata": { | 116 | + ] |
123 | - "id": "QwCN4YFUc0M8" | 117 | + }, |
124 | - }, | 118 | + { |
125 | - "source": [ | 119 | + "cell_type": "code", |
126 | - "Set your region\n", | 120 | + "metadata": { |
127 | - "\n", | 121 | + "id": "L_mInbOKcoc2" |
128 | - "Code | Region\n", | 122 | + }, |
129 | - "--- | ---\n", | 123 | + "source": [ |
130 | - "us | United States\n", | 124 | + "# Paste your authtoken here in quotes\n", |
131 | - "eu | Europe\n", | 125 | + "authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\"" |
132 | - "ap | Asia/Pacific\n", | 126 | + ], |
133 | - "au | Australia\n", | 127 | + "execution_count": null, |
134 | - "sa | South America\n", | 128 | + "outputs": [] |
135 | - "jp | Japan\n", | 129 | + }, |
136 | - "in | India" | 130 | + { |
137 | - ] | 131 | + "cell_type": "markdown", |
138 | - }, | 132 | + "metadata": { |
139 | - { | 133 | + "id": "QwCN4YFUc0M8" |
140 | - "cell_type": "code", | 134 | + }, |
141 | - "metadata": { | 135 | + "source": [ |
142 | - "id": "p4LSNN2xc0dQ" | 136 | + "Set your region\n", |
143 | - }, | 137 | + "\n", |
144 | - "source": [ | 138 | + "Code | Region\n", |
145 | - "# Set your region here in quotes\n", | 139 | + "--- | ---\n", |
146 | - "region = \"jp\"\n", | 140 | + "us | United States\n", |
147 | - "\n", | 141 | + "eu | Europe\n", |
148 | - "# Input and output ports for communication\n", | 142 | + "ap | Asia/Pacific\n", |
149 | - "local_in_port = 5000\n", | 143 | + "au | Australia\n", |
150 | - "local_out_port = 5000" | 144 | + "sa | South America\n", |
151 | - ], | 145 | + "jp | Japan\n", |
152 | - "execution_count": null, | 146 | + "in | India" |
153 | - "outputs": [] | 147 | + ] |
154 | - }, | 148 | + }, |
155 | - { | 149 | + { |
156 | - "cell_type": "code", | 150 | + "cell_type": "code", |
157 | - "metadata": { | 151 | + "metadata": { |
158 | - "id": "kg56PVrOdhi1" | 152 | + "id": "p4LSNN2xc0dQ" |
159 | - }, | 153 | + }, |
160 | - "source": [ | 154 | + "source": [ |
161 | - "config =\\\n", | 155 | + "# Set your region here in quotes\n", |
162 | - "f\"\"\"\n", | 156 | + "region = \"jp\"\n", |
163 | - "authtoken: {authtoken}\n", | 157 | + "\n", |
164 | - "region: {region}\n", | 158 | + "# Input and output ports for communication\n", |
165 | - "console_ui: False\n", | 159 | + "local_in_port = 5000\n", |
166 | - "tunnels:\n", | 160 | + "local_out_port = 5000" |
167 | - " input:\n", | 161 | + ], |
168 | - " addr: {local_in_port}\n", | 162 | + "execution_count": null, |
169 | - " proto: http \n", | 163 | + "outputs": [] |
170 | - " output:\n", | 164 | + }, |
171 | - " addr: {local_out_port}\n", | 165 | + { |
172 | - " proto: http\n", | 166 | + "cell_type": "code", |
173 | - "\"\"\"\n", | 167 | + "metadata": { |
174 | - "\n", | 168 | + "id": "kg56PVrOdhi1" |
175 | - "with open('ngrok.conf', 'w') as f:\n", | 169 | + }, |
176 | - " f.write(config)" | 170 | + "source": [ |
177 | - ], | 171 | + "config =\\\n", |
178 | - "execution_count": null, | 172 | + "f\"\"\"\n", |
179 | - "outputs": [] | 173 | + "authtoken: {authtoken}\n", |
180 | - }, | 174 | + "region: {region}\n", |
181 | - { | 175 | + "console_ui: False\n", |
182 | - "cell_type": "code", | 176 | + "tunnels:\n", |
183 | - "metadata": { | 177 | + " input:\n", |
184 | - "id": "hrWDrw_YdjIy" | 178 | + " addr: {local_in_port}\n", |
185 | - }, | 179 | + " proto: http \n", |
186 | - "source": [ | 180 | + " output:\n", |
187 | - "import time\n", | 181 | + " addr: {local_out_port}\n", |
188 | - "from subprocess import Popen, PIPE\n", | 182 | + " proto: http\n", |
189 | - "\n", | 183 | + "\"\"\"\n", |
190 | - "# (Re)Open tunnel\n", | 184 | + "\n", |
191 | - "ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n", | 185 | + "with open('ngrok.conf', 'w') as f:\n", |
192 | - "time.sleep(3)" | 186 | + " f.write(config)" |
193 | - ], | 187 | + ], |
194 | - "execution_count": null, | 188 | + "execution_count": null, |
195 | - "outputs": [] | 189 | + "outputs": [] |
196 | - }, | 190 | + }, |
197 | - { | 191 | + { |
198 | - "cell_type": "code", | 192 | + "cell_type": "code", |
199 | - "metadata": { | 193 | + "metadata": { |
200 | - "id": "pJgdFr0Fdjoq", | 194 | + "id": "hrWDrw_YdjIy" |
201 | - "outputId": "3948f70b-d4f3-4ed8-a864-fe5c6df50809", | 195 | + }, |
202 | - "colab": { | 196 | + "source": [ |
203 | - "base_uri": "https://localhost:8080/" | 197 | + "import time\n", |
204 | - } | 198 | + "from subprocess import Popen, PIPE\n", |
205 | - }, | 199 | + "\n", |
206 | - "source": [ | 200 | + "# (Re)Open tunnel\n", |
207 | - "# Get tunnel addresses\n", | 201 | + "ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n", |
208 | - "try:\n", | 202 | + "time.sleep(3)" |
209 | - " in_addr, out_addr = get_tunnel_adresses()\n", | 203 | + ], |
210 | - " print(\"Tunnel opened\")\n", | 204 | + "execution_count": null, |
211 | - "except Exception as e:\n", | 205 | + "outputs": [] |
212 | - " [print(l.decode(), end='') for l in ps.stdout.readlines()]\n", | 206 | + }, |
213 | - " print(\"Something went wrong, reopen the tunnel\")" | 207 | + { |
214 | - ], | 208 | + "cell_type": "code", |
215 | - "execution_count": null, | 209 | + "metadata": { |
216 | - "outputs": [ | 210 | + "id": "pJgdFr0Fdjoq", |
217 | - { | 211 | + "outputId": "3948f70b-d4f3-4ed8-a864-fe5c6df50809", |
218 | - "output_type": "stream", | 212 | + "colab": { |
219 | - "text": [ | 213 | + "base_uri": "https://localhost:8080/" |
220 | - "Opening tunnel\n", | 214 | + } |
221 | - "Something went wrong, reopen the tunnel\n" | 215 | + }, |
222 | - ], | 216 | + "source": [ |
223 | - "name": "stdout" | 217 | + "# Get tunnel addresses\n", |
224 | - } | 218 | + "try:\n", |
225 | - ] | 219 | + " in_addr, out_addr = get_tunnel_adresses()\n", |
226 | - }, | 220 | + " print(\"Tunnel opened\")\n", |
227 | - { | 221 | + "except Exception as e:\n", |
228 | - "cell_type": "markdown", | 222 | + " [print(l.decode(), end='') for l in ps.stdout.readlines()]\n", |
229 | - "metadata": { | 223 | + " print(\"Something went wrong, reopen the tunnel\")" |
230 | - "id": "cEZ-O0wz74OJ" | 224 | + ], |
231 | - }, | 225 | + "execution_count": null, |
232 | - "source": [ | 226 | + "outputs": [ |
233 | - "#### Run your server!" | |
234 | - ] | ||
235 | - }, | ||
236 | - { | ||
237 | - "cell_type": "code", | ||
238 | - "metadata": { | ||
239 | - "id": "7PRkeYTL8Y_6" | ||
240 | - }, | ||
241 | - "source": [ | ||
242 | - "import os\n", | ||
243 | - "import torch\n", | ||
244 | - "import argparse\n", | ||
245 | - "from tqdm import tqdm\n", | ||
246 | - "import torch.nn as nn\n", | ||
247 | - "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n", | ||
248 | - "from transformers import (RobertaConfig, RobertaTokenizer)\n", | ||
249 | - "\n", | ||
250 | - "from commit.model import Seq2Seq\n", | ||
251 | - "from commit.utils import (Example, convert_examples_to_features)\n", | ||
252 | - "from commit.model.diff_roberta import RobertaModel\n", | ||
253 | - "\n", | ||
254 | - "from flask import Flask, jsonify, request\n", | ||
255 | - "\n", | ||
256 | - "MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}" | ||
257 | - ], | ||
258 | - "execution_count": null, | ||
259 | - "outputs": [] | ||
260 | - }, | ||
261 | - { | ||
262 | - "cell_type": "code", | ||
263 | - "metadata": { | ||
264 | - "id": "CiJKucX17qb4" | ||
265 | - }, | ||
266 | - "source": [ | ||
267 | - "def get_model(model_class, config, tokenizer, mode):\n", | ||
268 | - " encoder = model_class(config=config)\n", | ||
269 | - " decoder_layer = nn.TransformerDecoderLayer(\n", | ||
270 | - " d_model=config.hidden_size, nhead=config.num_attention_heads\n", | ||
271 | - " )\n", | ||
272 | - " decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n", | ||
273 | - " model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n", | ||
274 | - " beam_size=args.beam_size, max_length=args.max_target_length,\n", | ||
275 | - " sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n", | ||
276 | - "\n", | ||
277 | - " assert args.load_model_path\n", | ||
278 | - " assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n", | ||
279 | - "\n", | ||
280 | - " model.load_state_dict(\n", | ||
281 | - " torch.load(\n", | ||
282 | - " os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n", | ||
283 | - " map_location=torch.device('cpu')\n", | ||
284 | - " ),\n", | ||
285 | - " strict=False\n", | ||
286 | - " )\n", | ||
287 | - " return model\n", | ||
288 | - "\n", | ||
289 | - "def get_features(examples):\n", | ||
290 | - " features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n", | ||
291 | - " all_source_ids = torch.tensor(\n", | ||
292 | - " [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
293 | - " )\n", | ||
294 | - " all_source_mask = torch.tensor(\n", | ||
295 | - " [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
296 | - " )\n", | ||
297 | - " all_patch_ids = torch.tensor(\n", | ||
298 | - " [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
299 | - " )\n", | ||
300 | - " return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n", | ||
301 | - "\n", | ||
302 | - "def create_app():\n", | ||
303 | - " @app.route('/')\n", | ||
304 | - " def index():\n", | ||
305 | - " return jsonify(hello=\"world\")\n", | ||
306 | - "\n", | ||
307 | - " @app.route('/added', methods=['POST'])\n", | ||
308 | - " def added():\n", | ||
309 | - " if request.method == 'POST':\n", | ||
310 | - " payload = request.get_json()\n", | ||
311 | - " example = [\n", | ||
312 | - " Example(\n", | ||
313 | - " idx=payload['idx'],\n", | ||
314 | - " added=payload['added'],\n", | ||
315 | - " deleted=payload['deleted'],\n", | ||
316 | - " target=None\n", | ||
317 | - " )\n", | ||
318 | - " ]\n", | ||
319 | - " message = inference(model=args.added_model, data=get_features(example))\n", | ||
320 | - " return jsonify(idx=payload['idx'], message=message)\n", | ||
321 | - "\n", | ||
322 | - " @app.route('/diff', methods=['POST'])\n", | ||
323 | - " def diff():\n", | ||
324 | - " if request.method == 'POST':\n", | ||
325 | - " payload = request.get_json()\n", | ||
326 | - " example = [\n", | ||
327 | - " Example(\n", | ||
328 | - " idx=payload['idx'],\n", | ||
329 | - " added=payload['added'],\n", | ||
330 | - " deleted=payload['deleted'],\n", | ||
331 | - " target=None\n", | ||
332 | - " )\n", | ||
333 | - " ]\n", | ||
334 | - " message = inference(model=args.diff_model, data=get_features(example))\n", | ||
335 | - " return jsonify(idx=payload['idx'], message=message)\n", | ||
336 | - "\n", | ||
337 | - " @app.route('/tokenizer', methods=['POST'])\n", | ||
338 | - " def tokenizer():\n", | ||
339 | - " if request.method == 'POST':\n", | ||
340 | - " payload = request.get_json()\n", | ||
341 | - " tokens = args.tokenizer.tokenize(payload['code'])\n", | ||
342 | - " return jsonify(tokens=tokens)\n", | ||
343 | - "\n", | ||
344 | - " return app\n", | ||
345 | - "\n", | ||
346 | - "def inference(model, data):\n", | ||
347 | - " # Run the model over the data in one batch and decode the predictions\n", | |
348 | - " eval_sampler = SequentialSampler(data)\n", | ||
349 | - " eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n", | ||
350 | - "\n", | ||
351 | - " model.eval()\n", | ||
352 | - " p=[]\n", | ||
353 | - " for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n", | ||
354 | - " batch = tuple(t.to(args.device) for t in batch)\n", | ||
355 | - " source_ids, source_mask, patch_ids = batch\n", | ||
356 | - " with torch.no_grad():\n", | ||
357 | - " preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n", | ||
358 | - " for pred in preds:\n", | ||
359 | - " t = pred[0].cpu().numpy()\n", | ||
360 | - " t = list(t)\n", | ||
361 | - " if 0 in t:\n", | ||
362 | - " t = t[:t.index(0)]\n", | ||
363 | - " text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n", | ||
364 | - " p.append(text)\n", | ||
365 | - " return p" | ||
366 | - ], | ||
367 | - "execution_count": null, | ||
368 | - "outputs": [] | ||
369 | - }, | ||
370 | - { | ||
371 | - "cell_type": "markdown", | ||
372 | - "metadata": { | ||
373 | - "id": "Esf4r-Ai8cG3" | ||
374 | - }, | ||
375 | - "source": [ | ||
376 | - "**Set environment**" | |
377 | - ] | ||
378 | - }, | ||
379 | - { | ||
380 | - "cell_type": "code", | ||
381 | - "metadata": { | ||
382 | - "id": "mR7gVmSoSUoy" | ||
383 | - }, | ||
384 | - "source": [ | ||
385 | - "import easydict \n", | ||
386 | - "\n", | ||
387 | - "args = easydict.EasyDict({\n", | ||
388 | - " 'load_model_path': 'weight/', \n", | ||
389 | - " 'model_type': 'roberta',\n", | ||
390 | - " 'config_name' : 'microsoft/codebert-base',\n", | ||
391 | - " 'tokenizer_name' : 'microsoft/codebert-base',\n", | ||
392 | - " 'max_source_length' : 512,\n", | ||
393 | - " 'max_target_length' : 128,\n", | ||
394 | - " 'beam_size' : 10,\n", | ||
395 | - " 'do_lower_case' : False,\n", | ||
396 | - " 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||
397 | - "})" | ||
398 | - ], | ||
399 | - "execution_count": null, | ||
400 | - "outputs": [] | ||
401 | - }, | ||
402 | - { | ||
403 | - "cell_type": "code", | ||
404 | - "metadata": { | ||
405 | - "id": "e8dk5RwvToOv" | ||
406 | - }, | ||
407 | - "source": [ | ||
408 | - "# flask_ngrok_example.py\n", | ||
409 | - "from flask_ngrok import run_with_ngrok\n", | ||
410 | - "\n", | ||
411 | - "app = Flask(__name__)\n", | ||
412 | - "run_with_ngrok(app) # Start ngrok when app is run\n", | ||
413 | - "\n", | ||
414 | - "config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n", | ||
415 | - "config = config_class.from_pretrained(args.config_name)\n", | ||
416 | - "args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n", | ||
417 | - "\n", | ||
418 | - "# build model\n", | |
419 | - "args.added_model =get_model(model_class=model_class, config=config,\n", | ||
420 | - " tokenizer=args.tokenizer, mode='added').to(args.device)\n", | ||
421 | - "args.diff_model = get_model(model_class=model_class, config=config,\n", | ||
422 | - " tokenizer=args.tokenizer, mode='diff').to(args.device)\n", | ||
423 | - "\n", | ||
424 | - "app = create_app()\n", | ||
425 | - "app.run()" | ||
426 | - ], | ||
427 | - "execution_count": null, | ||
428 | - "outputs": [] | ||
429 | - }, | ||
430 | { | 227 | { |
431 | - "cell_type": "markdown", | 228 | + "output_type": "stream", |
432 | - "metadata": { | 229 | + "text": [ |
433 | - "id": "DXkBcO_sU_VN" | 230 | + "Opening tunnel\n", |
434 | - }, | 231 | + "Something went wrong, reopen the tunnel\n" |
435 | - "source": [ | 232 | + ], |
436 | - "## Set commit configure\n", | 233 | + "name": "stdout" |
437 | - "Now, set commit configure on your local computer.\n", | ||
438 | - "```shell\n", | ||
439 | - "$ commit configure --endpoint http://********.ngrok.io\n", | ||
440 | - "```" | ||
441 | - ] | ||
442 | } | 234 | } |
443 | - ] | 235 | + ] |
236 | + }, | ||
237 | + { | ||
238 | + "cell_type": "markdown", | ||
239 | + "metadata": { | ||
240 | + "id": "cEZ-O0wz74OJ" | ||
241 | + }, | ||
242 | + "source": [ | ||
243 | + "#### Run your server!" | |
244 | + ] | ||
245 | + }, | ||
246 | + { | ||
247 | + "cell_type": "code", | ||
248 | + "metadata": { | ||
249 | + "id": "7PRkeYTL8Y_6" | ||
250 | + }, | ||
251 | + "source": [ | ||
252 | + "import os\n", | ||
253 | + "import torch\n", | ||
254 | + "import argparse\n", | ||
255 | + "from tqdm import tqdm\n", | ||
256 | + "import torch.nn as nn\n", | ||
257 | + "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n", | ||
258 | + "from transformers import (RobertaConfig, RobertaTokenizer)\n", | ||
259 | + "\n", | ||
260 | + "from commit.model import Seq2Seq\n", | ||
261 | + "from commit.utils import (Example, convert_examples_to_features)\n", | ||
262 | + "from commit.model.diff_roberta import RobertaModel\n", | ||
263 | + "\n", | ||
264 | + "from flask import Flask, jsonify, request\n", | ||
265 | + "\n", | ||
266 | + "MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}" | ||
267 | + ], | ||
268 | + "execution_count": null, | ||
269 | + "outputs": [] | ||
270 | + }, | ||
271 | + { | ||
272 | + "cell_type": "code", | ||
273 | + "metadata": { | ||
274 | + "id": "CiJKucX17qb4" | ||
275 | + }, | ||
276 | + "source": [ | ||
277 | + "def get_model(model_class, config, tokenizer, mode):\n", | ||
278 | + " encoder = model_class(config=config)\n", | ||
279 | + " decoder_layer = nn.TransformerDecoderLayer(\n", | ||
280 | + " d_model=config.hidden_size, nhead=config.num_attention_heads\n", | ||
281 | + " )\n", | ||
282 | + " decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n", | ||
283 | + " model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n", | ||
284 | + " beam_size=args.beam_size, max_length=args.max_target_length,\n", | ||
285 | + " sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n", | ||
286 | + "\n", | ||
287 | + " assert args.load_model_path\n", | ||
288 | + " assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n", | ||
289 | + "\n", | ||
290 | + " model.load_state_dict(\n", | ||
291 | + " torch.load(\n", | ||
292 | + " os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n", | ||
293 | + " map_location=torch.device('cpu')\n", | ||
294 | + " ),\n", | ||
295 | + " strict=False\n", | ||
296 | + " )\n", | ||
297 | + " return model\n", | ||
298 | + "\n", | ||
299 | + "def get_features(examples):\n", | ||
300 | + " features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n", | ||
301 | + " all_source_ids = torch.tensor(\n", | ||
302 | + " [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
303 | + " )\n", | ||
304 | + " all_source_mask = torch.tensor(\n", | ||
305 | + " [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
306 | + " )\n", | ||
307 | + " all_patch_ids = torch.tensor(\n", | ||
308 | + " [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
309 | + " )\n", | ||
310 | + " return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n", | ||
311 | + "\n", | ||
312 | + "def create_app():\n", | ||
313 | + " @app.route('/')\n", | ||
314 | + " def index():\n", | ||
315 | + " return jsonify(hello=\"world\")\n", | ||
316 | + "\n", | ||
317 | + " @app.route('/added', methods=['POST'])\n", | ||
318 | + " def added():\n", | ||
319 | + " if request.method == 'POST':\n", | ||
320 | + " payload = request.get_json()\n", | ||
321 | + " example = [\n", | ||
322 | + " Example(\n", | ||
323 | + " idx=payload['idx'],\n", | ||
324 | + " added=payload['added'],\n", | ||
325 | + " deleted=payload['deleted'],\n", | ||
326 | + " target=None\n", | ||
327 | + " )\n", | ||
328 | + " ]\n", | ||
329 | + " message = inference(model=args.added_model, data=get_features(example))\n", | ||
330 | + " return jsonify(idx=payload['idx'], message=message)\n", | ||
331 | + "\n", | ||
332 | + " @app.route('/diff', methods=['POST'])\n", | ||
333 | + " def diff():\n", | ||
334 | + " if request.method == 'POST':\n", | ||
335 | + " payload = request.get_json()\n", | ||
336 | + " example = [\n", | ||
337 | + " Example(\n", | ||
338 | + " idx=payload['idx'],\n", | ||
339 | + " added=payload['added'],\n", | ||
340 | + " deleted=payload['deleted'],\n", | ||
341 | + " target=None\n", | ||
342 | + " )\n", | ||
343 | + " ]\n", | ||
344 | + " message = inference(model=args.diff_model, data=get_features(example))\n", | ||
345 | + " return jsonify(idx=payload['idx'], message=message)\n", | ||
346 | + "\n", | ||
347 | + " @app.route('/tokenizer', methods=['POST'])\n", | ||
348 | + " def tokenizer():\n", | ||
349 | + " if request.method == 'POST':\n", | ||
350 | + " payload = request.get_json()\n", | ||
351 | + " tokens = args.tokenizer.tokenize(payload['code'])\n", | ||
352 | + " return jsonify(tokens=tokens)\n", | ||
353 | + "\n", | ||
354 | + " return app\n", | ||
355 | + "\n", | ||
356 | + "def inference(model, data):\n", | ||
357 | + " # Run the model over the data in one batch and decode the predictions\n", | |
358 | + " eval_sampler = SequentialSampler(data)\n", | ||
359 | + " eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n", | ||
360 | + "\n", | ||
361 | + " model.eval()\n", | ||
362 | + " p=[]\n", | ||
363 | + " for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n", | ||
364 | + " batch = tuple(t.to(args.device) for t in batch)\n", | ||
365 | + " source_ids, source_mask, patch_ids = batch\n", | ||
366 | + " with torch.no_grad():\n", | ||
367 | + " preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n", | ||
368 | + " for pred in preds:\n", | ||
369 | + " t = pred[0].cpu().numpy()\n", | ||
370 | + " t = list(t)\n", | ||
371 | + " if 0 in t:\n", | ||
372 | + " t = t[:t.index(0)]\n", | ||
373 | + " text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n", | ||
374 | + " p.append(text)\n", | ||
375 | + " return p" | ||
376 | + ], | ||
377 | + "execution_count": null, | ||
378 | + "outputs": [] | ||
379 | + }, | ||
380 | + { | ||
381 | + "cell_type": "markdown", | ||
382 | + "metadata": { | ||
383 | + "id": "Esf4r-Ai8cG3" | ||
384 | + }, | ||
385 | + "source": [ | ||
386 | + "**Set environment**" | |
387 | + ] | ||
388 | + }, | ||
389 | + { | ||
390 | + "cell_type": "code", | ||
391 | + "metadata": { | ||
392 | + "id": "mR7gVmSoSUoy" | ||
393 | + }, | ||
394 | + "source": [ | ||
395 | + "import easydict \n", | ||
396 | + "\n", | ||
397 | + "args = easydict.EasyDict({\n", | ||
398 | + " 'load_model_path': 'weight/', \n", | ||
399 | + " 'model_type': 'roberta',\n", | ||
400 | + " 'config_name' : 'microsoft/codebert-base',\n", | ||
401 | + " 'tokenizer_name' : 'microsoft/codebert-base',\n", | ||
402 | + " 'max_source_length' : 512,\n", | ||
403 | + " 'max_target_length' : 128,\n", | ||
404 | + " 'beam_size' : 10,\n", | ||
405 | + " 'do_lower_case' : False,\n", | ||
406 | + " 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||
407 | + "})" | ||
408 | + ], | ||
409 | + "execution_count": null, | ||
410 | + "outputs": [] | ||
411 | + }, | ||
412 | + { | ||
413 | + "cell_type": "code", | ||
414 | + "metadata": { | ||
415 | + "id": "e8dk5RwvToOv" | ||
416 | + }, | ||
417 | + "source": [ | ||
418 | + "# flask_ngrok_example.py\n", | ||
419 | + "from flask_ngrok import run_with_ngrok\n", | ||
420 | + "\n", | ||
421 | + "app = Flask(__name__)\n", | ||
422 | + "run_with_ngrok(app) # Start ngrok when app is run\n", | ||
423 | + "\n", | ||
424 | + "config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n", | ||
425 | + "config = config_class.from_pretrained(args.config_name)\n", | ||
426 | + "args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n", | ||
427 | + "\n", | ||
428 | + "# build model\n", | |
429 | + "args.added_model =get_model(model_class=model_class, config=config,\n", | ||
430 | + " tokenizer=args.tokenizer, mode='added').to(args.device)\n", | ||
431 | + "args.diff_model = get_model(model_class=model_class, config=config,\n", | ||
432 | + " tokenizer=args.tokenizer, mode='diff').to(args.device)\n", | ||
433 | + "\n", | ||
434 | + "app = create_app()\n", | ||
435 | + "app.run()" | ||
436 | + ], | ||
437 | + "execution_count": null, | ||
438 | + "outputs": [] | ||
439 | + }, | ||
440 | + { | ||
441 | + "cell_type": "markdown", | ||
442 | + "metadata": { | ||
443 | + "id": "DXkBcO_sU_VN" | ||
444 | + }, | ||
445 | + "source": [ | ||
446 | + "## Set commit configure\n", | ||
447 | + "Now, set commit configure on your local computer.\n", | ||
448 | + "```shell\n", | ||
449 | + "$ commit configure --endpoint http://********.ngrok.io\n", | ||
450 | + "```" | ||
451 | + ] | ||
452 | + } | ||
453 | + ] | ||
444 | } | 454 | } |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
-
Please register or login to post a comment