(add) jupyter notebook file in goolge colab for people who dont have nvidia gpu
Showing
1 changed file
with
444 additions
and
0 deletions
commit_autosuggestions.ipynb
0 → 100644
1 | +{ | ||
2 | + "nbformat": 4, | ||
3 | + "nbformat_minor": 0, | ||
4 | + "metadata": { | ||
5 | + "colab": { | ||
6 | + "name": "commit-autosuggestions.ipynb", | ||
7 | + "provenance": [], | ||
8 | + "collapsed_sections": [], | ||
9 | + "toc_visible": true | ||
10 | + }, | ||
11 | + "kernelspec": { | ||
12 | + "name": "python3", | ||
13 | + "display_name": "Python 3" | ||
14 | + }, | ||
15 | + "accelerator": "GPU" | ||
16 | + }, | ||
17 | + "cells": [ | ||
18 | + { | ||
19 | + "cell_type": "markdown", | ||
20 | + "metadata": { | ||
21 | + "id": "DZ7rFp2gzuNS" | ||
22 | + }, | ||
23 | + "source": [ | ||
24 | + "## Start commit-autosuggestions server\n", | ||
25 | + "Running flask app server in Google Colab for people without GPU" | ||
26 | + ] | ||
27 | + }, | ||
28 | + { | ||
29 | + "cell_type": "markdown", | ||
30 | + "metadata": { | ||
31 | + "id": "d8Lyin2I3wHq" | ||
32 | + }, | ||
33 | + "source": [ | ||
34 | + "#### Clone github repository" | ||
35 | + ] | ||
36 | + }, | ||
37 | + { | ||
38 | + "cell_type": "code", | ||
39 | + "metadata": { | ||
40 | + "id": "e_cu9igvzjcs" | ||
41 | + }, | ||
42 | + "source": [ | ||
43 | + "!git clone https://github.com/graykode/commit-autosuggestions.git\n", | ||
44 | + "%cd commit-autosuggestions\n", | ||
45 | + "!pip install -r requirements.txt" | ||
46 | + ], | ||
47 | + "execution_count": null, | ||
48 | + "outputs": [] | ||
49 | + }, | ||
50 | + { | ||
51 | + "cell_type": "markdown", | ||
52 | + "metadata": { | ||
53 | + "id": "PFKn5QZr0dQx" | ||
54 | + }, | ||
55 | + "source": [ | ||
56 | + "#### Download model weights\n", | ||
57 | + "\n", | ||
58 | + "Download the two weights of model from the google drive through the gdown module.\n", | ||
59 | + "1. [Added model](https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4) : A model trained Code2NL on Python using pre-trained CodeBERT (Feng at al, 2020).\n", | ||
60 | + "2. [Diff model](https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m) : A model retrained by initializing with the weight of model (1), adding embedding of the added and deleted parts(`patch_ids_embedding`) of the code." | ||
61 | + ] | ||
62 | + }, | ||
63 | + { | ||
64 | + "cell_type": "code", | ||
65 | + "metadata": { | ||
66 | + "id": "P9-EBpxt0Dp0" | ||
67 | + }, | ||
68 | + "source": [ | ||
69 | + "!pip install gdown \\\n", | ||
70 | + " && gdown \"https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4\" -O weight/added/pytorch_model.bin \\\n", | ||
71 | + " && gdown \"https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\" -O weight/diff/pytorch_model.bin" | ||
72 | + ], | ||
73 | + "execution_count": null, | ||
74 | + "outputs": [] | ||
75 | + }, | ||
76 | + { | ||
77 | + "cell_type": "markdown", | ||
78 | + "metadata": { | ||
79 | + "id": "org4Gqdv3iUu" | ||
80 | + }, | ||
81 | + "source": [ | ||
82 | + "#### ngrok setting with flask\n", | ||
83 | + "\n", | ||
84 | + "Before starting the server, you need to configure ngrok to open this notebook to the outside. I have referred [this jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb) in detail." | ||
85 | + ] | ||
86 | + }, | ||
87 | + { | ||
88 | + "cell_type": "code", | ||
89 | + "metadata": { | ||
90 | + "id": "lZA3kuuG1Crj" | ||
91 | + }, | ||
92 | + "source": [ | ||
93 | + "!pip install flask-ngrok" | ||
94 | + ], | ||
95 | + "execution_count": null, | ||
96 | + "outputs": [] | ||
97 | + }, | ||
98 | + { | ||
99 | + "cell_type": "markdown", | ||
100 | + "metadata": { | ||
101 | + "id": "hR78FRCMcqrZ" | ||
102 | + }, | ||
103 | + "source": [ | ||
104 | + "Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n", | ||
105 | + "\n" | ||
106 | + ] | ||
107 | + }, | ||
108 | + { | ||
109 | + "cell_type": "code", | ||
110 | + "metadata": { | ||
111 | + "id": "L_mInbOKcoc2" | ||
112 | + }, | ||
113 | + "source": [ | ||
114 | + "# Paste your authtoken here in quotes\n", | ||
115 | + "authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\"" | ||
116 | + ], | ||
117 | + "execution_count": null, | ||
118 | + "outputs": [] | ||
119 | + }, | ||
120 | + { | ||
121 | + "cell_type": "markdown", | ||
122 | + "metadata": { | ||
123 | + "id": "QwCN4YFUc0M8" | ||
124 | + }, | ||
125 | + "source": [ | ||
126 | + "Set your region\n", | ||
127 | + "\n", | ||
128 | + "Code | Region\n", | ||
129 | + "--- | ---\n", | ||
130 | + "us | United States\n", | ||
131 | + "eu | Europe\n", | ||
132 | + "ap | Asia/Pacific\n", | ||
133 | + "au | Australia\n", | ||
134 | + "sa | South America\n", | ||
135 | + "jp | Japan\n", | ||
136 | + "in | India" | ||
137 | + ] | ||
138 | + }, | ||
139 | + { | ||
140 | + "cell_type": "code", | ||
141 | + "metadata": { | ||
142 | + "id": "p4LSNN2xc0dQ" | ||
143 | + }, | ||
144 | + "source": [ | ||
145 | + "# Set your region here in quotes\n", | ||
146 | + "region = \"jp\"\n", | ||
147 | + "\n", | ||
148 | + "# Input and output ports for communication\n", | ||
149 | + "local_in_port = 5000\n", | ||
150 | + "local_out_port = 5000" | ||
151 | + ], | ||
152 | + "execution_count": null, | ||
153 | + "outputs": [] | ||
154 | + }, | ||
155 | + { | ||
156 | + "cell_type": "code", | ||
157 | + "metadata": { | ||
158 | + "id": "kg56PVrOdhi1" | ||
159 | + }, | ||
160 | + "source": [ | ||
161 | + "config =\\\n", | ||
162 | + "f\"\"\"\n", | ||
163 | + "authtoken: {authtoken}\n", | ||
164 | + "region: {region}\n", | ||
165 | + "console_ui: False\n", | ||
166 | + "tunnels:\n", | ||
167 | + " input:\n", | ||
168 | + " addr: {local_in_port}\n", | ||
169 | + " proto: http \n", | ||
170 | + " output:\n", | ||
171 | + " addr: {local_out_port}\n", | ||
172 | + " proto: http\n", | ||
173 | + "\"\"\"\n", | ||
174 | + "\n", | ||
175 | + "with open('ngrok.conf', 'w') as f:\n", | ||
176 | + " f.write(config)" | ||
177 | + ], | ||
178 | + "execution_count": null, | ||
179 | + "outputs": [] | ||
180 | + }, | ||
181 | + { | ||
182 | + "cell_type": "code", | ||
183 | + "metadata": { | ||
184 | + "id": "hrWDrw_YdjIy" | ||
185 | + }, | ||
186 | + "source": [ | ||
187 | + "import time\n", | ||
188 | + "from subprocess import Popen, PIPE\n", | ||
189 | + "\n", | ||
190 | + "# (Re)Open tunnel\n", | ||
191 | + "ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n", | ||
192 | + "time.sleep(3)" | ||
193 | + ], | ||
194 | + "execution_count": null, | ||
195 | + "outputs": [] | ||
196 | + }, | ||
197 | + { | ||
198 | + "cell_type": "code", | ||
199 | + "metadata": { | ||
200 | + "id": "pJgdFr0Fdjoq", | ||
201 | + "outputId": "3948f70b-d4f3-4ed8-a864-fe5c6df50809", | ||
202 | + "colab": { | ||
203 | + "base_uri": "https://localhost:8080/" | ||
204 | + } | ||
205 | + }, | ||
206 | + "source": [ | ||
207 | + "# Get tunnel addresses\n", | ||
208 | + "try:\n", | ||
209 | + " in_addr, out_addr = get_tunnel_adresses()\n", | ||
210 | + " print(\"Tunnel opened\")\n", | ||
211 | + "except Exception as e:\n", | ||
212 | + " [print(l.decode(), end='') for l in ps.stdout.readlines()]\n", | ||
213 | + " print(\"Something went wrong, reopen the tunnel\")" | ||
214 | + ], | ||
215 | + "execution_count": null, | ||
216 | + "outputs": [ | ||
217 | + { | ||
218 | + "output_type": "stream", | ||
219 | + "text": [ | ||
220 | + "Opening tunnel\n", | ||
221 | + "Something went wrong, reopen the tunnel\n" | ||
222 | + ], | ||
223 | + "name": "stdout" | ||
224 | + } | ||
225 | + ] | ||
226 | + }, | ||
227 | + { | ||
228 | + "cell_type": "markdown", | ||
229 | + "metadata": { | ||
230 | + "id": "cEZ-O0wz74OJ" | ||
231 | + }, | ||
232 | + "source": [ | ||
233 | + "#### Run you server!" | ||
234 | + ] | ||
235 | + }, | ||
236 | + { | ||
237 | + "cell_type": "code", | ||
238 | + "metadata": { | ||
239 | + "id": "7PRkeYTL8Y_6" | ||
240 | + }, | ||
241 | + "source": [ | ||
242 | + "import os\n", | ||
243 | + "import torch\n", | ||
244 | + "import argparse\n", | ||
245 | + "from tqdm import tqdm\n", | ||
246 | + "import torch.nn as nn\n", | ||
247 | + "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n", | ||
248 | + "from transformers import (RobertaConfig, RobertaTokenizer)\n", | ||
249 | + "\n", | ||
250 | + "from commit.model import Seq2Seq\n", | ||
251 | + "from commit.utils import (Example, convert_examples_to_features)\n", | ||
252 | + "from commit.model.diff_roberta import RobertaModel\n", | ||
253 | + "\n", | ||
254 | + "from flask import Flask, jsonify, request\n", | ||
255 | + "\n", | ||
256 | + "MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}" | ||
257 | + ], | ||
258 | + "execution_count": null, | ||
259 | + "outputs": [] | ||
260 | + }, | ||
261 | + { | ||
262 | + "cell_type": "code", | ||
263 | + "metadata": { | ||
264 | + "id": "CiJKucX17qb4" | ||
265 | + }, | ||
266 | + "source": [ | ||
267 | + "def get_model(model_class, config, tokenizer, mode):\n", | ||
268 | + " encoder = model_class(config=config)\n", | ||
269 | + " decoder_layer = nn.TransformerDecoderLayer(\n", | ||
270 | + " d_model=config.hidden_size, nhead=config.num_attention_heads\n", | ||
271 | + " )\n", | ||
272 | + " decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n", | ||
273 | + " model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n", | ||
274 | + " beam_size=args.beam_size, max_length=args.max_target_length,\n", | ||
275 | + " sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n", | ||
276 | + "\n", | ||
277 | + " assert args.load_model_path\n", | ||
278 | + " assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n", | ||
279 | + "\n", | ||
280 | + " model.load_state_dict(\n", | ||
281 | + " torch.load(\n", | ||
282 | + " os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n", | ||
283 | + " map_location=torch.device('cpu')\n", | ||
284 | + " ),\n", | ||
285 | + " strict=False\n", | ||
286 | + " )\n", | ||
287 | + " return model\n", | ||
288 | + "\n", | ||
289 | + "def get_features(examples):\n", | ||
290 | + " features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n", | ||
291 | + " all_source_ids = torch.tensor(\n", | ||
292 | + " [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
293 | + " )\n", | ||
294 | + " all_source_mask = torch.tensor(\n", | ||
295 | + " [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
296 | + " )\n", | ||
297 | + " all_patch_ids = torch.tensor(\n", | ||
298 | + " [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
299 | + " )\n", | ||
300 | + " return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n", | ||
301 | + "\n", | ||
302 | + "def create_app():\n", | ||
303 | + " @app.route('/')\n", | ||
304 | + " def index():\n", | ||
305 | + " return jsonify(hello=\"world\")\n", | ||
306 | + "\n", | ||
307 | + " @app.route('/added', methods=['POST'])\n", | ||
308 | + " def added():\n", | ||
309 | + " if request.method == 'POST':\n", | ||
310 | + " payload = request.get_json()\n", | ||
311 | + " example = [\n", | ||
312 | + " Example(\n", | ||
313 | + " idx=payload['idx'],\n", | ||
314 | + " added=payload['added'],\n", | ||
315 | + " deleted=payload['deleted'],\n", | ||
316 | + " target=None\n", | ||
317 | + " )\n", | ||
318 | + " ]\n", | ||
319 | + " message = inference(model=args.added_model, data=get_features(example))\n", | ||
320 | + " return jsonify(idx=payload['idx'], message=message)\n", | ||
321 | + "\n", | ||
322 | + " @app.route('/diff', methods=['POST'])\n", | ||
323 | + " def diff():\n", | ||
324 | + " if request.method == 'POST':\n", | ||
325 | + " payload = request.get_json()\n", | ||
326 | + " example = [\n", | ||
327 | + " Example(\n", | ||
328 | + " idx=payload['idx'],\n", | ||
329 | + " added=payload['added'],\n", | ||
330 | + " deleted=payload['deleted'],\n", | ||
331 | + " target=None\n", | ||
332 | + " )\n", | ||
333 | + " ]\n", | ||
334 | + " message = inference(model=args.diff_model, data=get_features(example))\n", | ||
335 | + " return jsonify(idx=payload['idx'], message=message)\n", | ||
336 | + "\n", | ||
337 | + " @app.route('/tokenizer', methods=['POST'])\n", | ||
338 | + " def tokenizer():\n", | ||
339 | + " if request.method == 'POST':\n", | ||
340 | + " payload = request.get_json()\n", | ||
341 | + " tokens = args.tokenizer.tokenize(payload['code'])\n", | ||
342 | + " return jsonify(tokens=tokens)\n", | ||
343 | + "\n", | ||
344 | + " return app\n", | ||
345 | + "\n", | ||
346 | + "def inference(model, data):\n", | ||
347 | + " # Calculate bleu\n", | ||
348 | + " eval_sampler = SequentialSampler(data)\n", | ||
349 | + " eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n", | ||
350 | + "\n", | ||
351 | + " model.eval()\n", | ||
352 | + " p=[]\n", | ||
353 | + " for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n", | ||
354 | + " batch = tuple(t.to(args.device) for t in batch)\n", | ||
355 | + " source_ids, source_mask, patch_ids = batch\n", | ||
356 | + " with torch.no_grad():\n", | ||
357 | + " preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n", | ||
358 | + " for pred in preds:\n", | ||
359 | + " t = pred[0].cpu().numpy()\n", | ||
360 | + " t = list(t)\n", | ||
361 | + " if 0 in t:\n", | ||
362 | + " t = t[:t.index(0)]\n", | ||
363 | + " text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n", | ||
364 | + " p.append(text)\n", | ||
365 | + " return p" | ||
366 | + ], | ||
367 | + "execution_count": null, | ||
368 | + "outputs": [] | ||
369 | + }, | ||
370 | + { | ||
371 | + "cell_type": "markdown", | ||
372 | + "metadata": { | ||
373 | + "id": "Esf4r-Ai8cG3" | ||
374 | + }, | ||
375 | + "source": [ | ||
376 | + "**Set enviroment**" | ||
377 | + ] | ||
378 | + }, | ||
379 | + { | ||
380 | + "cell_type": "code", | ||
381 | + "metadata": { | ||
382 | + "id": "mR7gVmSoSUoy" | ||
383 | + }, | ||
384 | + "source": [ | ||
385 | + "import easydict \n", | ||
386 | + "\n", | ||
387 | + "args = easydict.EasyDict({\n", | ||
388 | + " 'load_model_path': 'weight/', \n", | ||
389 | + " 'model_type': 'roberta',\n", | ||
390 | + " 'config_name' : 'microsoft/codebert-base',\n", | ||
391 | + " 'tokenizer_name' : 'microsoft/codebert-base',\n", | ||
392 | + " 'max_source_length' : 512,\n", | ||
393 | + " 'max_target_length' : 128,\n", | ||
394 | + " 'beam_size' : 10,\n", | ||
395 | + " 'do_lower_case' : False,\n", | ||
396 | + " 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||
397 | + "})" | ||
398 | + ], | ||
399 | + "execution_count": null, | ||
400 | + "outputs": [] | ||
401 | + }, | ||
402 | + { | ||
403 | + "cell_type": "code", | ||
404 | + "metadata": { | ||
405 | + "id": "e8dk5RwvToOv" | ||
406 | + }, | ||
407 | + "source": [ | ||
408 | + "# flask_ngrok_example.py\n", | ||
409 | + "from flask_ngrok import run_with_ngrok\n", | ||
410 | + "\n", | ||
411 | + "app = Flask(__name__)\n", | ||
412 | + "run_with_ngrok(app) # Start ngrok when app is run\n", | ||
413 | + "\n", | ||
414 | + "config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n", | ||
415 | + "config = config_class.from_pretrained(args.config_name)\n", | ||
416 | + "args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n", | ||
417 | + "\n", | ||
418 | + "# budild model\n", | ||
419 | + "args.added_model =get_model(model_class=model_class, config=config,\n", | ||
420 | + " tokenizer=args.tokenizer, mode='added').to(args.device)\n", | ||
421 | + "args.diff_model = get_model(model_class=model_class, config=config,\n", | ||
422 | + " tokenizer=args.tokenizer, mode='diff').to(args.device)\n", | ||
423 | + "\n", | ||
424 | + "app = create_app()\n", | ||
425 | + "app.run()" | ||
426 | + ], | ||
427 | + "execution_count": null, | ||
428 | + "outputs": [] | ||
429 | + }, | ||
430 | + { | ||
431 | + "cell_type": "markdown", | ||
432 | + "metadata": { | ||
433 | + "id": "DXkBcO_sU_VN" | ||
434 | + }, | ||
435 | + "source": [ | ||
436 | + "## Set commit configure\n", | ||
437 | + "Now, set commit configure on your local computer.\n", | ||
438 | + "```shell\n", | ||
439 | + "$ commit configure --endpoint http://********.ngrok.io\n", | ||
440 | + "```" | ||
441 | + ] | ||
442 | + } | ||
443 | + ] | ||
444 | +} | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment