graykode

(add) javascript support in google colab

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "commit-autosuggestions.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "DZ7rFp2gzuNS"
},
"source": [
"## Start commit-autosuggestions server\n",
"Running flask app server in Google Colab for people without GPU"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d8Lyin2I3wHq"
},
"source": [
"#### Clone github repository"
]
},
{
"cell_type": "code",
"metadata": {
"id": "e_cu9igvzjcs"
},
"source": [
"!git clone https://github.com/graykode/commit-autosuggestions.git\n",
"%cd commit-autosuggestions\n",
"!pip install -r requirements.txt"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PFKn5QZr0dQx"
},
"source": [
"#### Download model weights\n",
"\n",
"Download the two model weights from Google Drive using the gdown module.\n",
-"1. [Added model](https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4) : A model trained on the Code2NL task for Python using pre-trained CodeBERT (Feng et al., 2020).\n",
-"2. [Diff model](https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m) : A model retrained by initializing with the weight of model (1), adding embedding of the added and deleted parts (`patch_ids_embedding`) of the code."
+"1. Added model : A model trained on the Code2NL task for Python using pre-trained CodeBERT (Feng et al., 2020).\n",
+"2. Diff model : A model retrained by initializing with the weight of model (1), adding embedding of the added and deleted parts (`patch_ids_embedding`) of the code.\n",
+"\n",
+"Download the pre-trained weights:\n",
+"\n",
+"Language | Added | Diff\n",
+"--- | --- | ---\n",
+"python | 1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4 | 1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\n",
+"javascript | 1-F68ymKxZ-htCzQ8_Y9iHexs2SJmP5Gc | 1-39rmu-3clwebNURMQGMt-oM4HsAkbsf"
]
},
{
"cell_type": "code",
"metadata": {
"id": "P9-EBpxt0Dp0"
},
"source": [
-"!pip install gdown \\\n",
-" && gdown \"https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4\" -O weight/added/pytorch_model.bin \\\n",
-" && gdown \"https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\" -O weight/diff/pytorch_model.bin"
+"ADD_MODEL='1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4'\n",
+"DIFF_MODEL='1--gcVVix92_Fp75A-mWH0pJS0ahlni5m'\n",
+"\n",
+"!pip install gdown \\\n",
+" && gdown \"https://drive.google.com/uc?id=$ADD_MODEL\" -O weight/added/pytorch_model.bin \\\n",
+" && gdown \"https://drive.google.com/uc?id=$DIFF_MODEL\" -O weight/diff/pytorch_model.bin"
],
"execution_count": null,
"outputs": []
},
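Since this commit is about javascript support, note that switching languages now only means switching the two Drive IDs in the cell above. A variant of the same cell for the javascript weights, with the IDs taken from the table (download paths unchanged):

```python
# javascript weights, per the Language/Added/Diff table above
ADD_MODEL = '1-F68ymKxZ-htCzQ8_Y9iHexs2SJmP5Gc'
DIFF_MODEL = '1-39rmu-3clwebNURMQGMt-oM4HsAkbsf'

!pip install gdown \
  && gdown "https://drive.google.com/uc?id=$ADD_MODEL" -O weight/added/pytorch_model.bin \
  && gdown "https://drive.google.com/uc?id=$DIFF_MODEL" -O weight/diff/pytorch_model.bin
```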
92 - "source": [ 86 + {
93 - "!pip install flask-ngrok" 87 + "cell_type": "markdown",
94 - ], 88 + "metadata": {
95 - "execution_count": null, 89 + "id": "org4Gqdv3iUu"
96 - "outputs": [] 90 + },
97 - }, 91 + "source": [
98 - { 92 + "#### ngrok setting with flask\n",
99 - "cell_type": "markdown", 93 + "\n",
100 - "metadata": { 94 + "Before starting the server, you need to configure ngrok to open this notebook to the outside. I have referred [this jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb) in detail."
101 - "id": "hR78FRCMcqrZ" 95 + ]
102 - }, 96 + },
103 - "source": [ 97 + {
104 - "Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n", 98 + "cell_type": "code",
105 - "\n" 99 + "metadata": {
106 - ] 100 + "id": "lZA3kuuG1Crj"
107 - }, 101 + },
108 - { 102 + "source": [
109 - "cell_type": "code", 103 + "!pip install flask-ngrok"
110 - "metadata": { 104 + ],
111 - "id": "L_mInbOKcoc2" 105 + "execution_count": null,
112 - }, 106 + "outputs": []
113 - "source": [ 107 + },
114 - "# Paste your authtoken here in quotes\n", 108 + {
115 - "authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\"" 109 + "cell_type": "markdown",
116 - ], 110 + "metadata": {
117 - "execution_count": null, 111 + "id": "hR78FRCMcqrZ"
118 - "outputs": [] 112 + },
119 - }, 113 + "source": [
120 - { 114 + "Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n",
121 - "cell_type": "markdown", 115 + "\n"
122 - "metadata": { 116 + ]
123 - "id": "QwCN4YFUc0M8" 117 + },
124 - }, 118 + {
125 - "source": [ 119 + "cell_type": "code",
126 - "Set your region\n", 120 + "metadata": {
127 - "\n", 121 + "id": "L_mInbOKcoc2"
128 - "Code | Region\n", 122 + },
129 - "--- | ---\n", 123 + "source": [
130 - "us | United States\n", 124 + "# Paste your authtoken here in quotes\n",
131 - "eu | Europe\n", 125 + "authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\""
132 - "ap | Asia/Pacific\n", 126 + ],
133 - "au | Australia\n", 127 + "execution_count": null,
134 - "sa | South America\n", 128 + "outputs": []
135 - "jp | Japan\n", 129 + },
136 - "in | India" 130 + {
137 - ] 131 + "cell_type": "markdown",
138 - }, 132 + "metadata": {
139 - { 133 + "id": "QwCN4YFUc0M8"
140 - "cell_type": "code", 134 + },
141 - "metadata": { 135 + "source": [
142 - "id": "p4LSNN2xc0dQ" 136 + "Set your region\n",
143 - }, 137 + "\n",
144 - "source": [ 138 + "Code | Region\n",
145 - "# Set your region here in quotes\n", 139 + "--- | ---\n",
146 - "region = \"jp\"\n", 140 + "us | United States\n",
147 - "\n", 141 + "eu | Europe\n",
148 - "# Input and output ports for communication\n", 142 + "ap | Asia/Pacific\n",
149 - "local_in_port = 5000\n", 143 + "au | Australia\n",
150 - "local_out_port = 5000" 144 + "sa | South America\n",
151 - ], 145 + "jp | Japan\n",
152 - "execution_count": null, 146 + "in | India"
153 - "outputs": [] 147 + ]
154 - }, 148 + },
155 - { 149 + {
156 - "cell_type": "code", 150 + "cell_type": "code",
157 - "metadata": { 151 + "metadata": {
158 - "id": "kg56PVrOdhi1" 152 + "id": "p4LSNN2xc0dQ"
159 - }, 153 + },
160 - "source": [ 154 + "source": [
161 - "config =\\\n", 155 + "# Set your region here in quotes\n",
162 - "f\"\"\"\n", 156 + "region = \"jp\"\n",
163 - "authtoken: {authtoken}\n", 157 + "\n",
164 - "region: {region}\n", 158 + "# Input and output ports for communication\n",
165 - "console_ui: False\n", 159 + "local_in_port = 5000\n",
166 - "tunnels:\n", 160 + "local_out_port = 5000"
167 - " input:\n", 161 + ],
168 - " addr: {local_in_port}\n", 162 + "execution_count": null,
169 - " proto: http \n", 163 + "outputs": []
170 - " output:\n", 164 + },
171 - " addr: {local_out_port}\n", 165 + {
172 - " proto: http\n", 166 + "cell_type": "code",
173 - "\"\"\"\n", 167 + "metadata": {
174 - "\n", 168 + "id": "kg56PVrOdhi1"
175 - "with open('ngrok.conf', 'w') as f:\n", 169 + },
176 - " f.write(config)" 170 + "source": [
177 - ], 171 + "config =\\\n",
178 - "execution_count": null, 172 + "f\"\"\"\n",
179 - "outputs": [] 173 + "authtoken: {authtoken}\n",
180 - }, 174 + "region: {region}\n",
181 - { 175 + "console_ui: False\n",
182 - "cell_type": "code", 176 + "tunnels:\n",
183 - "metadata": { 177 + " input:\n",
184 - "id": "hrWDrw_YdjIy" 178 + " addr: {local_in_port}\n",
185 - }, 179 + " proto: http \n",
186 - "source": [ 180 + " output:\n",
187 - "import time\n", 181 + " addr: {local_out_port}\n",
188 - "from subprocess import Popen, PIPE\n", 182 + " proto: http\n",
189 - "\n", 183 + "\"\"\"\n",
190 - "# (Re)Open tunnel\n", 184 + "\n",
191 - "ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n", 185 + "with open('ngrok.conf', 'w') as f:\n",
192 - "time.sleep(3)" 186 + " f.write(config)"
193 - ], 187 + ],
194 - "execution_count": null, 188 + "execution_count": null,
195 - "outputs": [] 189 + "outputs": []
196 - }, 190 + },
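For reference, with the defaults above (region "jp", both ports 5000), the cell writes an ngrok.conf like the following; only the authtoken line differs per user (shown here as a placeholder):

```yaml
authtoken: <your-authtoken>
region: jp
console_ui: False
tunnels:
  input:
    addr: 5000
    proto: http
  output:
    addr: 5000
    proto: http
```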
{
"cell_type": "code",
"metadata": {
"id": "hrWDrw_YdjIy"
},
"source": [
"import time\n",
"from subprocess import Popen, PIPE\n",
"\n",
"# (Re)Open tunnel\n",
"ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n",
"time.sleep(3)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pJgdFr0Fdjoq",
"outputId": "3948f70b-d4f3-4ed8-a864-fe5c6df50809",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# Get tunnel addresses\n",
"try:\n",
"    in_addr, out_addr = get_tunnel_adresses()\n",
"    print(\"Tunnel opened\")\n",
"except Exception as e:\n",
"    [print(l.decode(), end='') for l in ps.stdout.readlines()]\n",
"    print(\"Something went wrong, reopen the tunnel\")"
],
"execution_count": null,
"outputs": [
233 - "#### Run you server!"
234 - ]
235 - },
236 - {
237 - "cell_type": "code",
238 - "metadata": {
239 - "id": "7PRkeYTL8Y_6"
240 - },
241 - "source": [
242 - "import os\n",
243 - "import torch\n",
244 - "import argparse\n",
245 - "from tqdm import tqdm\n",
246 - "import torch.nn as nn\n",
247 - "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n",
248 - "from transformers import (RobertaConfig, RobertaTokenizer)\n",
249 - "\n",
250 - "from commit.model import Seq2Seq\n",
251 - "from commit.utils import (Example, convert_examples_to_features)\n",
252 - "from commit.model.diff_roberta import RobertaModel\n",
253 - "\n",
254 - "from flask import Flask, jsonify, request\n",
255 - "\n",
256 - "MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}"
257 - ],
258 - "execution_count": null,
259 - "outputs": []
260 - },
261 - {
262 - "cell_type": "code",
263 - "metadata": {
264 - "id": "CiJKucX17qb4"
265 - },
266 - "source": [
267 - "def get_model(model_class, config, tokenizer, mode):\n",
268 - " encoder = model_class(config=config)\n",
269 - " decoder_layer = nn.TransformerDecoderLayer(\n",
270 - " d_model=config.hidden_size, nhead=config.num_attention_heads\n",
271 - " )\n",
272 - " decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n",
273 - " model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n",
274 - " beam_size=args.beam_size, max_length=args.max_target_length,\n",
275 - " sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n",
276 - "\n",
277 - " assert args.load_model_path\n",
278 - " assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n",
279 - "\n",
280 - " model.load_state_dict(\n",
281 - " torch.load(\n",
282 - " os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n",
283 - " map_location=torch.device('cpu')\n",
284 - " ),\n",
285 - " strict=False\n",
286 - " )\n",
287 - " return model\n",
288 - "\n",
289 - "def get_features(examples):\n",
290 - " features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n",
291 - " all_source_ids = torch.tensor(\n",
292 - " [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
293 - " )\n",
294 - " all_source_mask = torch.tensor(\n",
295 - " [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n",
296 - " )\n",
297 - " all_patch_ids = torch.tensor(\n",
298 - " [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
299 - " )\n",
300 - " return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n",
301 - "\n",
302 - "def create_app():\n",
303 - " @app.route('/')\n",
304 - " def index():\n",
305 - " return jsonify(hello=\"world\")\n",
306 - "\n",
307 - " @app.route('/added', methods=['POST'])\n",
308 - " def added():\n",
309 - " if request.method == 'POST':\n",
310 - " payload = request.get_json()\n",
311 - " example = [\n",
312 - " Example(\n",
313 - " idx=payload['idx'],\n",
314 - " added=payload['added'],\n",
315 - " deleted=payload['deleted'],\n",
316 - " target=None\n",
317 - " )\n",
318 - " ]\n",
319 - " message = inference(model=args.added_model, data=get_features(example))\n",
320 - " return jsonify(idx=payload['idx'], message=message)\n",
321 - "\n",
322 - " @app.route('/diff', methods=['POST'])\n",
323 - " def diff():\n",
324 - " if request.method == 'POST':\n",
325 - " payload = request.get_json()\n",
326 - " example = [\n",
327 - " Example(\n",
328 - " idx=payload['idx'],\n",
329 - " added=payload['added'],\n",
330 - " deleted=payload['deleted'],\n",
331 - " target=None\n",
332 - " )\n",
333 - " ]\n",
334 - " message = inference(model=args.diff_model, data=get_features(example))\n",
335 - " return jsonify(idx=payload['idx'], message=message)\n",
336 - "\n",
337 - " @app.route('/tokenizer', methods=['POST'])\n",
338 - " def tokenizer():\n",
339 - " if request.method == 'POST':\n",
340 - " payload = request.get_json()\n",
341 - " tokens = args.tokenizer.tokenize(payload['code'])\n",
342 - " return jsonify(tokens=tokens)\n",
343 - "\n",
344 - " return app\n",
345 - "\n",
346 - "def inference(model, data):\n",
347 - " # Calculate bleu\n",
348 - " eval_sampler = SequentialSampler(data)\n",
349 - " eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n",
350 - "\n",
351 - " model.eval()\n",
352 - " p=[]\n",
353 - " for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n",
354 - " batch = tuple(t.to(args.device) for t in batch)\n",
355 - " source_ids, source_mask, patch_ids = batch\n",
356 - " with torch.no_grad():\n",
357 - " preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n",
358 - " for pred in preds:\n",
359 - " t = pred[0].cpu().numpy()\n",
360 - " t = list(t)\n",
361 - " if 0 in t:\n",
362 - " t = t[:t.index(0)]\n",
363 - " text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n",
364 - " p.append(text)\n",
365 - " return p"
366 - ],
367 - "execution_count": null,
368 - "outputs": []
369 - },
370 - {
371 - "cell_type": "markdown",
372 - "metadata": {
373 - "id": "Esf4r-Ai8cG3"
374 - },
375 - "source": [
376 - "**Set enviroment**"
377 - ]
378 - },
379 - {
380 - "cell_type": "code",
381 - "metadata": {
382 - "id": "mR7gVmSoSUoy"
383 - },
384 - "source": [
385 - "import easydict \n",
386 - "\n",
387 - "args = easydict.EasyDict({\n",
388 - " 'load_model_path': 'weight/', \n",
389 - " 'model_type': 'roberta',\n",
390 - " 'config_name' : 'microsoft/codebert-base',\n",
391 - " 'tokenizer_name' : 'microsoft/codebert-base',\n",
392 - " 'max_source_length' : 512,\n",
393 - " 'max_target_length' : 128,\n",
394 - " 'beam_size' : 10,\n",
395 - " 'do_lower_case' : False,\n",
396 - " 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
397 - "})"
398 - ],
399 - "execution_count": null,
400 - "outputs": []
401 - },
402 - {
403 - "cell_type": "code",
404 - "metadata": {
405 - "id": "e8dk5RwvToOv"
406 - },
407 - "source": [
408 - "# flask_ngrok_example.py\n",
409 - "from flask_ngrok import run_with_ngrok\n",
410 - "\n",
411 - "app = Flask(__name__)\n",
412 - "run_with_ngrok(app) # Start ngrok when app is run\n",
413 - "\n",
414 - "config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n",
415 - "config = config_class.from_pretrained(args.config_name)\n",
416 - "args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n",
417 - "\n",
418 - "# budild model\n",
419 - "args.added_model =get_model(model_class=model_class, config=config,\n",
420 - " tokenizer=args.tokenizer, mode='added').to(args.device)\n",
421 - "args.diff_model = get_model(model_class=model_class, config=config,\n",
422 - " tokenizer=args.tokenizer, mode='diff').to(args.device)\n",
423 - "\n",
424 - "app = create_app()\n",
425 - "app.run()"
426 - ],
427 - "execution_count": null,
428 - "outputs": []
429 - },
{
"output_type": "stream",
"text": [
"Opening tunnel\n",
"Something went wrong, reopen the tunnel\n"
],
"name": "stdout"
}
]
},
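`get_tunnel_adresses` (spelling as in the notebook) is not defined in any cell here; it is expected to come from the cloned repository. As a rough sketch only, not the repository's actual implementation, such a helper can read the public URLs of the `input` and `output` tunnels declared in ngrok.conf from ngrok's local inspection API:

```python
import json
from urllib.request import urlopen

def get_tunnel_adresses(api='http://localhost:4040/api/tunnels'):
    # ngrok's local inspection API lists every open tunnel with its public URL
    tunnels = json.load(urlopen(api))['tunnels']
    addrs = {t['name']: t['public_url'] for t in tunnels}
    # 'input' and 'output' are the tunnel names from ngrok.conf
    return addrs['input'], addrs['output']
```

If the tunnel script failed (as in the recorded output above), the `except` branch dumps the script's stdout, which is usually enough to see why.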
{
"cell_type": "markdown",
"metadata": {
"id": "cEZ-O0wz74OJ"
},
"source": [
"#### Run your server!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7PRkeYTL8Y_6"
},
"source": [
"import os\n",
"import torch\n",
"import argparse\n",
"from tqdm import tqdm\n",
"import torch.nn as nn\n",
"from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n",
"from transformers import (RobertaConfig, RobertaTokenizer)\n",
"\n",
"from commit.model import Seq2Seq\n",
"from commit.utils import (Example, convert_examples_to_features)\n",
"from commit.model.diff_roberta import RobertaModel\n",
"\n",
"from flask import Flask, jsonify, request\n",
"\n",
"MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CiJKucX17qb4"
},
"source": [
"def get_model(model_class, config, tokenizer, mode):\n",
"    encoder = model_class(config=config)\n",
"    decoder_layer = nn.TransformerDecoderLayer(\n",
"        d_model=config.hidden_size, nhead=config.num_attention_heads\n",
"    )\n",
"    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n",
"    model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n",
"                    beam_size=args.beam_size, max_length=args.max_target_length,\n",
"                    sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n",
"\n",
"    assert args.load_model_path\n",
"    assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n",
"\n",
"    model.load_state_dict(\n",
"        torch.load(\n",
"            os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n",
"            map_location=torch.device('cpu')\n",
"        ),\n",
"        strict=False\n",
"    )\n",
"    return model\n",
"\n",
"def get_features(examples):\n",
"    features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n",
"    all_source_ids = torch.tensor(\n",
"        [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
"    )\n",
"    all_source_mask = torch.tensor(\n",
"        [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n",
"    )\n",
"    all_patch_ids = torch.tensor(\n",
"        [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
"    )\n",
"    return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n",
"\n",
"def create_app():\n",
"    @app.route('/')\n",
"    def index():\n",
"        return jsonify(hello=\"world\")\n",
"\n",
"    @app.route('/added', methods=['POST'])\n",
"    def added():\n",
"        if request.method == 'POST':\n",
"            payload = request.get_json()\n",
"            example = [\n",
"                Example(\n",
"                    idx=payload['idx'],\n",
"                    added=payload['added'],\n",
"                    deleted=payload['deleted'],\n",
"                    target=None\n",
"                )\n",
"            ]\n",
"            message = inference(model=args.added_model, data=get_features(example))\n",
"            return jsonify(idx=payload['idx'], message=message)\n",
"\n",
"    @app.route('/diff', methods=['POST'])\n",
"    def diff():\n",
"        if request.method == 'POST':\n",
"            payload = request.get_json()\n",
"            example = [\n",
"                Example(\n",
"                    idx=payload['idx'],\n",
"                    added=payload['added'],\n",
"                    deleted=payload['deleted'],\n",
"                    target=None\n",
"                )\n",
"            ]\n",
"            message = inference(model=args.diff_model, data=get_features(example))\n",
"            return jsonify(idx=payload['idx'], message=message)\n",
"\n",
"    @app.route('/tokenizer', methods=['POST'])\n",
"    def tokenizer():\n",
"        if request.method == 'POST':\n",
"            payload = request.get_json()\n",
"            tokens = args.tokenizer.tokenize(payload['code'])\n",
"            return jsonify(tokens=tokens)\n",
"\n",
"    return app\n",
"\n",
"def inference(model, data):\n",
"    # Run beam search over the batch and decode the predictions\n",
"    eval_sampler = SequentialSampler(data)\n",
"    eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n",
"\n",
"    model.eval()\n",
"    p = []\n",
"    for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n",
"        batch = tuple(t.to(args.device) for t in batch)\n",
"        source_ids, source_mask, patch_ids = batch\n",
"        with torch.no_grad():\n",
"            preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n",
"        for pred in preds:\n",
"            t = pred[0].cpu().numpy()\n",
"            t = list(t)\n",
"            # Truncate at the first padding token (id 0)\n",
"            if 0 in t:\n",
"                t = t[:t.index(0)]\n",
"            text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n",
"            p.append(text)\n",
"    return p"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Esf4r-Ai8cG3"
},
"source": [
"**Set environment**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mR7gVmSoSUoy"
},
"source": [
"import easydict\n",
"\n",
"args = easydict.EasyDict({\n",
"    'load_model_path': 'weight/',\n",
"    'model_type': 'roberta',\n",
"    'config_name': 'microsoft/codebert-base',\n",
"    'tokenizer_name': 'microsoft/codebert-base',\n",
"    'max_source_length': 512,\n",
"    'max_target_length': 128,\n",
"    'beam_size': 10,\n",
"    'do_lower_case': False,\n",
"    'device': torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"})"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "e8dk5RwvToOv"
},
"source": [
"# flask_ngrok_example.py\n",
"from flask_ngrok import run_with_ngrok\n",
"\n",
"app = Flask(__name__)\n",
"run_with_ngrok(app)  # Start ngrok when app is run\n",
"\n",
"config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n",
"config = config_class.from_pretrained(args.config_name)\n",
"args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n",
"\n",
"# build models\n",
"args.added_model = get_model(model_class=model_class, config=config,\n",
"                             tokenizer=args.tokenizer, mode='added').to(args.device)\n",
"args.diff_model = get_model(model_class=model_class, config=config,\n",
"                            tokenizer=args.tokenizer, mode='diff').to(args.device)\n",
"\n",
"app = create_app()\n",
"app.run()"
],
"execution_count": null,
"outputs": []
},
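Once `app.run()` is up, it prints the public ngrok URL of the Flask app. The `/added` and `/diff` routes expect a JSON body with `idx`, `added`, and `deleted` fields and answer with a list of suggested messages; the assumption here is that `added` and `deleted` hold token lists produced the same way the `/tokenizer` route tokenizes code. A hypothetical client call against a placeholder endpoint:

```python
import requests

endpoint = 'http://<your-subdomain>.ngrok.io'  # placeholder: the URL printed by app.run()

# Sanity check: the index route answers {"hello": "world"}
print(requests.get(endpoint).json())

# Tokenize a code fragment with the server's own tokenizer
tokens = requests.post(f'{endpoint}/tokenizer',
                       json={'code': "console.log('hello');"}).json()['tokens']

# Ask the Added model for commit-message suggestions
res = requests.post(f'{endpoint}/added',
                    json={'idx': 0, 'added': tokens, 'deleted': []})
print(res.json())  # {"idx": 0, "message": [...]}
```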
{
"cell_type": "markdown",
"metadata": {
"id": "DXkBcO_sU_VN"
},
"source": [
"## Set commit configure\n",
"Now, configure the commit endpoint on your local computer:\n",
"```shell\n",
"$ commit configure --endpoint http://********.ngrok.io\n",
"```"
]
}
]
}