Showing
1 changed file
with
449 additions
and
439 deletions
| 1 | { | 1 | { |
| 2 | - "nbformat": 4, | 2 | + "nbformat": 4, |
| 3 | - "nbformat_minor": 0, | 3 | + "nbformat_minor": 0, |
| 4 | - "metadata": { | 4 | + "metadata": { |
| 5 | - "colab": { | 5 | + "colab": { |
| 6 | - "name": "commit-autosuggestions.ipynb", | 6 | + "name": "commit-autosuggestions.ipynb", |
| 7 | - "provenance": [], | 7 | + "provenance": [], |
| 8 | - "collapsed_sections": [], | 8 | + "collapsed_sections": [], |
| 9 | - "toc_visible": true | 9 | + "toc_visible": true |
| 10 | - }, | ||
| 11 | - "kernelspec": { | ||
| 12 | - "name": "python3", | ||
| 13 | - "display_name": "Python 3" | ||
| 14 | - }, | ||
| 15 | - "accelerator": "GPU" | ||
| 16 | }, | 10 | }, |
| 17 | - "cells": [ | 11 | + "kernelspec": { |
| 18 | - { | 12 | + "name": "python3", |
| 19 | - "cell_type": "markdown", | 13 | + "display_name": "Python 3" |
| 20 | - "metadata": { | 14 | + }, |
| 21 | - "id": "DZ7rFp2gzuNS" | 15 | + "accelerator": "GPU" |
| 22 | - }, | 16 | + }, |
| 23 | - "source": [ | 17 | + "cells": [ |
| 24 | - "## Start commit-autosuggestions server\n", | 18 | + { |
| 25 | - "Running flask app server in Google Colab for people without GPU" | 19 | + "cell_type": "markdown", |
| 26 | - ] | 20 | + "metadata": { |
| 27 | - }, | 21 | + "id": "DZ7rFp2gzuNS" |
| 28 | - { | 22 | + }, |
| 29 | - "cell_type": "markdown", | 23 | + "source": [ |
| 30 | - "metadata": { | 24 | + "## Start commit-autosuggestions server\n", |
| 31 | - "id": "d8Lyin2I3wHq" | 25 | + "Running flask app server in Google Colab for people without GPU" |
| 32 | - }, | 26 | + ] |
| 33 | - "source": [ | 27 | + }, |
| 34 | - "#### Clone github repository" | 28 | + { |
| 35 | - ] | 29 | + "cell_type": "markdown", |
| 36 | - }, | 30 | + "metadata": { |
| 37 | - { | 31 | + "id": "d8Lyin2I3wHq" |
| 38 | - "cell_type": "code", | 32 | + }, |
| 39 | - "metadata": { | 33 | + "source": [ |
| 40 | - "id": "e_cu9igvzjcs" | 34 | + "#### Clone github repository" |
| 41 | - }, | 35 | + ] |
| 42 | - "source": [ | 36 | + }, |
| 43 | - "!git clone https://github.com/graykode/commit-autosuggestions.git\n", | 37 | + { |
| 44 | - "%cd commit-autosuggestions\n", | 38 | + "cell_type": "code", |
| 45 | - "!pip install -r requirements.txt" | 39 | + "metadata": { |
| 46 | - ], | 40 | + "id": "e_cu9igvzjcs" |
| 47 | - "execution_count": null, | 41 | + }, |
| 48 | - "outputs": [] | 42 | + "source": [ |
| 49 | - }, | 43 | + "!git clone https://github.com/graykode/commit-autosuggestions.git\n", |
| 50 | - { | 44 | + "%cd commit-autosuggestions\n", |
| 51 | - "cell_type": "markdown", | 45 | + "!pip install -r requirements.txt" |
| 52 | - "metadata": { | 46 | + ], |
| 53 | - "id": "PFKn5QZr0dQx" | 47 | + "execution_count": null, |
| 54 | - }, | 48 | + "outputs": [] |
| 55 | - "source": [ | 49 | + }, |
| 56 | - "#### Download model weights\n", | 50 | + { |
| 57 | - "\n", | 51 | + "cell_type": "markdown", |
| 58 | - "Download the two weights of model from the google drive through the gdown module.\n", | 52 | + "metadata": { |
| 59 | - "1. [Added model](https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4) : A model trained Code2NL on Python using pre-trained CodeBERT (Feng at al, 2020).\n", | 53 | + "id": "PFKn5QZr0dQx" |
| 60 | - "2. [Diff model](https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m) : A model retrained by initializing with the weight of model (1), adding embedding of the added and deleted parts(`patch_ids_embedding`) of the code." | 54 | + }, |
| 61 | - ] | 55 | + "source": [ |
| 62 | - }, | 56 | + "#### Download model weights\n", |
| 63 | - { | 57 | + "\n", |
| 64 | - "cell_type": "code", | 58 | + "Download the two weights of model from the google drive through the gdown module.\n", |
| 65 | - "id": "P9-EBpxt0Dp0" | 59 | + "1. Added model : A model trained Code2NL on Python using pre-trained CodeBERT (Feng et al., 2020).\n", |
| 66 | - "id": "P9-EBpxt0Dp0" | 60 | + "2. Diff model : A model retrained by initializing with the weight of model (1), adding embedding of the added and deleted parts(`patch_ids_embedding`) of the code.\n", |
| 67 | - }, | 61 | + "\n", |
| 68 | - "source": [ | 62 | + "Download pre-trained weight\n", |
| 69 | - "!pip install gdown \\\n", | 63 | + "\n", |
| 70 | - " && gdown \"https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4\" -O weight/added/pytorch_model.bin \\\n", | 64 | + "Language | Added | Diff\n", |
| 71 | - " && gdown \"https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\" -O weight/diff/pytorch_model.bin" | 65 | + "--- | --- | ---\n", |
| 72 | - ], | 66 | + "python | 1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4 | 1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\n", |
| 73 | - "execution_count": null, | 67 | + "javascript | 1-F68ymKxZ-htCzQ8_Y9iHexs2SJmP5Gc | 1-39rmu-3clwebNURMQGMt-oM4HsAkbsf" |
| 74 | - "outputs": [] | 68 | + ] |
| 75 | - }, | 69 | + }, |
| 76 | - { | 70 | + { |
| 77 | - "cell_type": "markdown", | 71 | + "cell_type": "code", |
| 78 | - "metadata": { | 72 | + "metadata": { |
| 79 | - "id": "org4Gqdv3iUu" | 73 | + "id": "P9-EBpxt0Dp0" |
| 80 | - }, | 74 | + }, |
| 81 | - "source": [ | 75 | + "source": [ |
| 82 | - "#### ngrok setting with flask\n", | 76 | + "ADD_MODEL='1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4'\n", |
| 83 | - "\n", | 77 | + "DIFF_MODEL='1--gcVVix92_Fp75A-mWH0pJS0ahlni5m'\n", |
| 84 | - "Before starting the server, you need to configure ngrok to open this notebook to the outside. I have referred [this jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb) in detail." | 78 | + "\n", |
| 85 | - ] | 79 | + "!pip install gdown \\\n", |
| 86 | - }, | 80 | + " && gdown \"https://drive.google.com/uc?id=$ADD_MODEL\" -O weight/added/pytorch_model.bin \\\n", |
| 87 | - { | 81 | + " && gdown \"https://drive.google.com/uc?id=$DIFF_MODEL\" -O weight/diff/pytorch_model.bin" |
| 88 | - "cell_type": "code", | 82 | + ], |
| 89 | - "metadata": { | 83 | + "execution_count": null, |
| 90 | - "id": "lZA3kuuG1Crj" | 84 | + "outputs": [] |
| 91 | - }, | 85 | + }, |
| 92 | - "source": [ | 86 | + { |
| 93 | - "!pip install flask-ngrok" | 87 | + "cell_type": "markdown", |
| 94 | - ], | 88 | + "metadata": { |
| 95 | - "execution_count": null, | 89 | + "id": "org4Gqdv3iUu" |
| 96 | - "outputs": [] | 90 | + }, |
| 97 | - }, | 91 | + "source": [ |
| 98 | - { | 92 | + "#### ngrok setting with flask\n", |
| 99 | - "cell_type": "markdown", | 93 | + "\n", |
| 100 | - "metadata": { | 94 | + "Before starting the server, you need to configure ngrok to open this notebook to the outside. I have referred [this jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb) in detail." |
| 101 | - "id": "hR78FRCMcqrZ" | 95 | + ] |
| 102 | - }, | 96 | + }, |
| 103 | - "source": [ | 97 | + { |
| 104 | - "Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n", | 98 | + "cell_type": "code", |
| 105 | - "\n" | 99 | + "metadata": { |
| 106 | - ] | 100 | + "id": "lZA3kuuG1Crj" |
| 107 | - }, | 101 | + }, |
| 108 | - { | 102 | + "source": [ |
| 109 | - "cell_type": "code", | 103 | + "!pip install flask-ngrok" |
| 110 | - "metadata": { | 104 | + ], |
| 111 | - "id": "L_mInbOKcoc2" | 105 | + "execution_count": null, |
| 112 | - }, | 106 | + "outputs": [] |
| 113 | - "source": [ | 107 | + }, |
| 114 | - "# Paste your authtoken here in quotes\n", | 108 | + { |
| 115 | - "authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\"" | 109 | + "cell_type": "markdown", |
| 116 | - ], | 110 | + "metadata": { |
| 117 | - "execution_count": null, | 111 | + "id": "hR78FRCMcqrZ" |
| 118 | - "outputs": [] | 112 | + }, |
| 119 | - }, | 113 | + "source": [ |
| 120 | - { | 114 | + "Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n", |
| 121 | - "cell_type": "markdown", | 115 | + "\n" |
| 122 | - "metadata": { | 116 | + ] |
| 123 | - "id": "QwCN4YFUc0M8" | 117 | + }, |
| 124 | - }, | 118 | + { |
| 125 | - "source": [ | 119 | + "cell_type": "code", |
| 126 | - "Set your region\n", | 120 | + "metadata": { |
| 127 | - "\n", | 121 | + "id": "L_mInbOKcoc2" |
| 128 | - "Code | Region\n", | 122 | + }, |
| 129 | - "--- | ---\n", | 123 | + "source": [ |
| 130 | - "us | United States\n", | 124 | + "# Paste your authtoken here in quotes\n", |
| 131 | - "eu | Europe\n", | 125 | + "authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\"" |
| 132 | - "ap | Asia/Pacific\n", | 126 | + ], |
| 133 | - "au | Australia\n", | 127 | + "execution_count": null, |
| 134 | - "sa | South America\n", | 128 | + "outputs": [] |
| 135 | - "jp | Japan\n", | 129 | + }, |
| 136 | - "in | India" | 130 | + { |
| 137 | - ] | 131 | + "cell_type": "markdown", |
| 138 | - }, | 132 | + "metadata": { |
| 139 | - { | 133 | + "id": "QwCN4YFUc0M8" |
| 140 | - "cell_type": "code", | 134 | + }, |
| 141 | - "metadata": { | 135 | + "source": [ |
| 142 | - "id": "p4LSNN2xc0dQ" | 136 | + "Set your region\n", |
| 143 | - }, | 137 | + "\n", |
| 144 | - "source": [ | 138 | + "Code | Region\n", |
| 145 | - "# Set your region here in quotes\n", | 139 | + "--- | ---\n", |
| 146 | - "region = \"jp\"\n", | 140 | + "us | United States\n", |
| 147 | - "\n", | 141 | + "eu | Europe\n", |
| 148 | - "# Input and output ports for communication\n", | 142 | + "ap | Asia/Pacific\n", |
| 149 | - "local_in_port = 5000\n", | 143 | + "au | Australia\n", |
| 150 | - "local_out_port = 5000" | 144 | + "sa | South America\n", |
| 151 | - ], | 145 | + "jp | Japan\n", |
| 152 | - "execution_count": null, | 146 | + "in | India" |
| 153 | - "outputs": [] | 147 | + ] |
| 154 | - }, | 148 | + }, |
| 155 | - { | 149 | + { |
| 156 | - "cell_type": "code", | 150 | + "cell_type": "code", |
| 157 | - "metadata": { | 151 | + "metadata": { |
| 158 | - "id": "kg56PVrOdhi1" | 152 | + "id": "p4LSNN2xc0dQ" |
| 159 | - }, | 153 | + }, |
| 160 | - "source": [ | 154 | + "source": [ |
| 161 | - "config =\\\n", | 155 | + "# Set your region here in quotes\n", |
| 162 | - "f\"\"\"\n", | 156 | + "region = \"jp\"\n", |
| 163 | - "authtoken: {authtoken}\n", | 157 | + "\n", |
| 164 | - "region: {region}\n", | 158 | + "# Input and output ports for communication\n", |
| 165 | - "console_ui: False\n", | 159 | + "local_in_port = 5000\n", |
| 166 | - "tunnels:\n", | 160 | + "local_out_port = 5000" |
| 167 | - " input:\n", | 161 | + ], |
| 168 | - " addr: {local_in_port}\n", | 162 | + "execution_count": null, |
| 169 | - " proto: http \n", | 163 | + "outputs": [] |
| 170 | - " output:\n", | 164 | + }, |
| 171 | - " addr: {local_out_port}\n", | 165 | + { |
| 172 | - " proto: http\n", | 166 | + "cell_type": "code", |
| 173 | - "\"\"\"\n", | 167 | + "metadata": { |
| 174 | - "\n", | 168 | + "id": "kg56PVrOdhi1" |
| 175 | - "with open('ngrok.conf', 'w') as f:\n", | 169 | + }, |
| 176 | - " f.write(config)" | 170 | + "source": [ |
| 177 | - ], | 171 | + "config =\\\n", |
| 178 | - "execution_count": null, | 172 | + "f\"\"\"\n", |
| 179 | - "outputs": [] | 173 | + "authtoken: {authtoken}\n", |
| 180 | - }, | 174 | + "region: {region}\n", |
| 181 | - { | 175 | + "console_ui: False\n", |
| 182 | - "cell_type": "code", | 176 | + "tunnels:\n", |
| 183 | - "metadata": { | 177 | + " input:\n", |
| 184 | - "id": "hrWDrw_YdjIy" | 178 | + " addr: {local_in_port}\n", |
| 185 | - }, | 179 | + " proto: http \n", |
| 186 | - "source": [ | 180 | + " output:\n", |
| 187 | - "import time\n", | 181 | + " addr: {local_out_port}\n", |
| 188 | - "from subprocess import Popen, PIPE\n", | 182 | + " proto: http\n", |
| 189 | - "\n", | 183 | + "\"\"\"\n", |
| 190 | - "# (Re)Open tunnel\n", | 184 | + "\n", |
| 191 | - "ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n", | 185 | + "with open('ngrok.conf', 'w') as f:\n", |
| 192 | - "time.sleep(3)" | 186 | + " f.write(config)" |
| 193 | - ], | 187 | + ], |
| 194 | - "execution_count": null, | 188 | + "execution_count": null, |
| 195 | - "outputs": [] | 189 | + "outputs": [] |
| 196 | - }, | 190 | + }, |
| 197 | - { | 191 | + { |
| 198 | - "cell_type": "code", | 192 | + "cell_type": "code", |
| 199 | - "metadata": { | 193 | + "metadata": { |
| 200 | - "id": "pJgdFr0Fdjoq", | 194 | + "id": "hrWDrw_YdjIy" |
| 201 | - "outputId": "3948f70b-d4f3-4ed8-a864-fe5c6df50809", | 195 | + }, |
| 202 | - "colab": { | 196 | + "source": [ |
| 203 | - "base_uri": "https://localhost:8080/" | 197 | + "import time\n", |
| 204 | - } | 198 | + "from subprocess import Popen, PIPE\n", |
| 205 | - }, | 199 | + "\n", |
| 206 | - "source": [ | 200 | + "# (Re)Open tunnel\n", |
| 207 | - "# Get tunnel addresses\n", | 201 | + "ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n", |
| 208 | - "try:\n", | 202 | + "time.sleep(3)" |
| 209 | - " in_addr, out_addr = get_tunnel_adresses()\n", | 203 | + ], |
| 210 | - " print(\"Tunnel opened\")\n", | 204 | + "execution_count": null, |
| 211 | - "except Exception as e:\n", | 205 | + "outputs": [] |
| 212 | - " [print(l.decode(), end='') for l in ps.stdout.readlines()]\n", | 206 | + }, |
| 213 | - " print(\"Something went wrong, reopen the tunnel\")" | 207 | + { |
| 214 | - ], | 208 | + "cell_type": "code", |
| 215 | - "execution_count": null, | 209 | + "metadata": { |
| 216 | - "outputs": [ | 210 | + "id": "pJgdFr0Fdjoq", |
| 217 | - { | 211 | + "outputId": "3948f70b-d4f3-4ed8-a864-fe5c6df50809", |
| 218 | - "output_type": "stream", | 212 | + "colab": { |
| 219 | - "text": [ | 213 | + "base_uri": "https://localhost:8080/" |
| 220 | - "Opening tunnel\n", | 214 | + } |
| 221 | - "Something went wrong, reopen the tunnel\n" | 215 | + }, |
| 222 | - ], | 216 | + "source": [ |
| 223 | - "name": "stdout" | 217 | + "# Get tunnel addresses\n", |
| 224 | - } | 218 | + "try:\n", |
| 225 | - ] | 219 | + " in_addr, out_addr = get_tunnel_adresses()\n", |
| 226 | - }, | 220 | + " print(\"Tunnel opened\")\n", |
| 227 | - { | 221 | + "except Exception as e:\n", |
| 228 | - "cell_type": "markdown", | 222 | + " [print(l.decode(), end='') for l in ps.stdout.readlines()]\n", |
| 229 | - "metadata": { | 223 | + " print(\"Something went wrong, reopen the tunnel\")" |
| 230 | - "id": "cEZ-O0wz74OJ" | 224 | + ], |
| 231 | - }, | 225 | + "execution_count": null, |
| 232 | - "source": [ | 226 | + "outputs": [ |
| 233 | - "#### Run you server!" | ||
| 234 | - ] | ||
| 235 | - }, | ||
| 236 | - { | ||
| 237 | - "cell_type": "code", | ||
| 238 | - "metadata": { | ||
| 239 | - "id": "7PRkeYTL8Y_6" | ||
| 240 | - }, | ||
| 241 | - "source": [ | ||
| 242 | - "import os\n", | ||
| 243 | - "import torch\n", | ||
| 244 | - "import argparse\n", | ||
| 245 | - "from tqdm import tqdm\n", | ||
| 246 | - "import torch.nn as nn\n", | ||
| 247 | - "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n", | ||
| 248 | - "from transformers import (RobertaConfig, RobertaTokenizer)\n", | ||
| 249 | - "\n", | ||
| 250 | - "from commit.model import Seq2Seq\n", | ||
| 251 | - "from commit.utils import (Example, convert_examples_to_features)\n", | ||
| 252 | - "from commit.model.diff_roberta import RobertaModel\n", | ||
| 253 | - "\n", | ||
| 254 | - "from flask import Flask, jsonify, request\n", | ||
| 255 | - "\n", | ||
| 256 | - "MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}" | ||
| 257 | - ], | ||
| 258 | - "execution_count": null, | ||
| 259 | - "outputs": [] | ||
| 260 | - }, | ||
| 261 | - { | ||
| 262 | - "cell_type": "code", | ||
| 263 | - "metadata": { | ||
| 264 | - "id": "CiJKucX17qb4" | ||
| 265 | - }, | ||
| 266 | - "source": [ | ||
| 267 | - "def get_model(model_class, config, tokenizer, mode):\n", | ||
| 268 | - " encoder = model_class(config=config)\n", | ||
| 269 | - " decoder_layer = nn.TransformerDecoderLayer(\n", | ||
| 270 | - " d_model=config.hidden_size, nhead=config.num_attention_heads\n", | ||
| 271 | - " )\n", | ||
| 272 | - " decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n", | ||
| 273 | - " model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n", | ||
| 274 | - " beam_size=args.beam_size, max_length=args.max_target_length,\n", | ||
| 275 | - " sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n", | ||
| 276 | - "\n", | ||
| 277 | - " assert args.load_model_path\n", | ||
| 278 | - " assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n", | ||
| 279 | - "\n", | ||
| 280 | - " model.load_state_dict(\n", | ||
| 281 | - " torch.load(\n", | ||
| 282 | - " os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n", | ||
| 283 | - " map_location=torch.device('cpu')\n", | ||
| 284 | - " ),\n", | ||
| 285 | - " strict=False\n", | ||
| 286 | - " )\n", | ||
| 287 | - " return model\n", | ||
| 288 | - "\n", | ||
| 289 | - "def get_features(examples):\n", | ||
| 290 | - " features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n", | ||
| 291 | - " all_source_ids = torch.tensor(\n", | ||
| 292 | - " [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
| 293 | - " )\n", | ||
| 294 | - " all_source_mask = torch.tensor(\n", | ||
| 295 | - " [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
| 296 | - " )\n", | ||
| 297 | - " all_patch_ids = torch.tensor(\n", | ||
| 298 | - " [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
| 299 | - " )\n", | ||
| 300 | - " return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n", | ||
| 301 | - "\n", | ||
| 302 | - "def create_app():\n", | ||
| 303 | - " @app.route('/')\n", | ||
| 304 | - " def index():\n", | ||
| 305 | - " return jsonify(hello=\"world\")\n", | ||
| 306 | - "\n", | ||
| 307 | - " @app.route('/added', methods=['POST'])\n", | ||
| 308 | - " def added():\n", | ||
| 309 | - " if request.method == 'POST':\n", | ||
| 310 | - " payload = request.get_json()\n", | ||
| 311 | - " example = [\n", | ||
| 312 | - " Example(\n", | ||
| 313 | - " idx=payload['idx'],\n", | ||
| 314 | - " added=payload['added'],\n", | ||
| 315 | - " deleted=payload['deleted'],\n", | ||
| 316 | - " target=None\n", | ||
| 317 | - " )\n", | ||
| 318 | - " ]\n", | ||
| 319 | - " message = inference(model=args.added_model, data=get_features(example))\n", | ||
| 320 | - " return jsonify(idx=payload['idx'], message=message)\n", | ||
| 321 | - "\n", | ||
| 322 | - " @app.route('/diff', methods=['POST'])\n", | ||
| 323 | - " def diff():\n", | ||
| 324 | - " if request.method == 'POST':\n", | ||
| 325 | - " payload = request.get_json()\n", | ||
| 326 | - " example = [\n", | ||
| 327 | - " Example(\n", | ||
| 328 | - " idx=payload['idx'],\n", | ||
| 329 | - " added=payload['added'],\n", | ||
| 330 | - " deleted=payload['deleted'],\n", | ||
| 331 | - " target=None\n", | ||
| 332 | - " )\n", | ||
| 333 | - " ]\n", | ||
| 334 | - " message = inference(model=args.diff_model, data=get_features(example))\n", | ||
| 335 | - " return jsonify(idx=payload['idx'], message=message)\n", | ||
| 336 | - "\n", | ||
| 337 | - " @app.route('/tokenizer', methods=['POST'])\n", | ||
| 338 | - " def tokenizer():\n", | ||
| 339 | - " if request.method == 'POST':\n", | ||
| 340 | - " payload = request.get_json()\n", | ||
| 341 | - " tokens = args.tokenizer.tokenize(payload['code'])\n", | ||
| 342 | - " return jsonify(tokens=tokens)\n", | ||
| 343 | - "\n", | ||
| 344 | - " return app\n", | ||
| 345 | - "\n", | ||
| 346 | - "def inference(model, data):\n", | ||
| 347 | - " # Calculate bleu\n", | ||
| 348 | - " eval_sampler = SequentialSampler(data)\n", | ||
| 349 | - " eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n", | ||
| 350 | - "\n", | ||
| 351 | - " model.eval()\n", | ||
| 352 | - " p=[]\n", | ||
| 353 | - " for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n", | ||
| 354 | - " batch = tuple(t.to(args.device) for t in batch)\n", | ||
| 355 | - " source_ids, source_mask, patch_ids = batch\n", | ||
| 356 | - " with torch.no_grad():\n", | ||
| 357 | - " preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n", | ||
| 358 | - " for pred in preds:\n", | ||
| 359 | - " t = pred[0].cpu().numpy()\n", | ||
| 360 | - " t = list(t)\n", | ||
| 361 | - " if 0 in t:\n", | ||
| 362 | - " t = t[:t.index(0)]\n", | ||
| 363 | - " text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n", | ||
| 364 | - " p.append(text)\n", | ||
| 365 | - " return p" | ||
| 366 | - ], | ||
| 367 | - "execution_count": null, | ||
| 368 | - "outputs": [] | ||
| 369 | - }, | ||
| 370 | - { | ||
| 371 | - "cell_type": "markdown", | ||
| 372 | - "metadata": { | ||
| 373 | - "id": "Esf4r-Ai8cG3" | ||
| 374 | - }, | ||
| 375 | - "source": [ | ||
| 376 | - "**Set enviroment**" | ||
| 377 | - ] | ||
| 378 | - }, | ||
| 379 | - { | ||
| 380 | - "cell_type": "code", | ||
| 381 | - "metadata": { | ||
| 382 | - "id": "mR7gVmSoSUoy" | ||
| 383 | - }, | ||
| 384 | - "source": [ | ||
| 385 | - "import easydict \n", | ||
| 386 | - "\n", | ||
| 387 | - "args = easydict.EasyDict({\n", | ||
| 388 | - " 'load_model_path': 'weight/', \n", | ||
| 389 | - " 'model_type': 'roberta',\n", | ||
| 390 | - " 'config_name' : 'microsoft/codebert-base',\n", | ||
| 391 | - " 'tokenizer_name' : 'microsoft/codebert-base',\n", | ||
| 392 | - " 'max_source_length' : 512,\n", | ||
| 393 | - " 'max_target_length' : 128,\n", | ||
| 394 | - " 'beam_size' : 10,\n", | ||
| 395 | - " 'do_lower_case' : False,\n", | ||
| 396 | - " 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||
| 397 | - "})" | ||
| 398 | - ], | ||
| 399 | - "execution_count": null, | ||
| 400 | - "outputs": [] | ||
| 401 | - }, | ||
| 402 | - { | ||
| 403 | - "cell_type": "code", | ||
| 404 | - "metadata": { | ||
| 405 | - "id": "e8dk5RwvToOv" | ||
| 406 | - }, | ||
| 407 | - "source": [ | ||
| 408 | - "# flask_ngrok_example.py\n", | ||
| 409 | - "from flask_ngrok import run_with_ngrok\n", | ||
| 410 | - "\n", | ||
| 411 | - "app = Flask(__name__)\n", | ||
| 412 | - "run_with_ngrok(app) # Start ngrok when app is run\n", | ||
| 413 | - "\n", | ||
| 414 | - "config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n", | ||
| 415 | - "config = config_class.from_pretrained(args.config_name)\n", | ||
| 416 | - "args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n", | ||
| 417 | - "\n", | ||
| 418 | - "# budild model\n", | ||
| 419 | - "args.added_model =get_model(model_class=model_class, config=config,\n", | ||
| 420 | - " tokenizer=args.tokenizer, mode='added').to(args.device)\n", | ||
| 421 | - "args.diff_model = get_model(model_class=model_class, config=config,\n", | ||
| 422 | - " tokenizer=args.tokenizer, mode='diff').to(args.device)\n", | ||
| 423 | - "\n", | ||
| 424 | - "app = create_app()\n", | ||
| 425 | - "app.run()" | ||
| 426 | - ], | ||
| 427 | - "execution_count": null, | ||
| 428 | - "outputs": [] | ||
| 429 | - }, | ||
| 430 | { | 227 | { |
| 431 | - "cell_type": "markdown", | 228 | + "output_type": "stream", |
| 432 | - "metadata": { | 229 | + "text": [ |
| 433 | - "id": "DXkBcO_sU_VN" | 230 | + "Opening tunnel\n", |
| 434 | - }, | 231 | + "Something went wrong, reopen the tunnel\n" |
| 435 | - "source": [ | 232 | + ], |
| 436 | - "## Set commit configure\n", | 233 | + "name": "stdout" |
| 437 | - "Now, set commit configure on your local computer.\n", | ||
| 438 | - "```shell\n", | ||
| 439 | - "$ commit configure --endpoint http://********.ngrok.io\n", | ||
| 440 | - "```" | ||
| 441 | - ] | ||
| 442 | } | 234 | } |
| 443 | - ] | 235 | + ] |
| 236 | + }, | ||
| 237 | + { | ||
| 238 | + "cell_type": "markdown", | ||
| 239 | + "metadata": { | ||
| 240 | + "id": "cEZ-O0wz74OJ" | ||
| 241 | + }, | ||
| 242 | + "source": [ | ||
| 243 | + "#### Run your server!" | ||
| 244 | + ] | ||
| 245 | + }, | ||
| 246 | + { | ||
| 247 | + "cell_type": "code", | ||
| 248 | + "metadata": { | ||
| 249 | + "id": "7PRkeYTL8Y_6" | ||
| 250 | + }, | ||
| 251 | + "source": [ | ||
| 252 | + "import os\n", | ||
| 253 | + "import torch\n", | ||
| 254 | + "import argparse\n", | ||
| 255 | + "from tqdm import tqdm\n", | ||
| 256 | + "import torch.nn as nn\n", | ||
| 257 | + "from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n", | ||
| 258 | + "from transformers import (RobertaConfig, RobertaTokenizer)\n", | ||
| 259 | + "\n", | ||
| 260 | + "from commit.model import Seq2Seq\n", | ||
| 261 | + "from commit.utils import (Example, convert_examples_to_features)\n", | ||
| 262 | + "from commit.model.diff_roberta import RobertaModel\n", | ||
| 263 | + "\n", | ||
| 264 | + "from flask import Flask, jsonify, request\n", | ||
| 265 | + "\n", | ||
| 266 | + "MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}" | ||
| 267 | + ], | ||
| 268 | + "execution_count": null, | ||
| 269 | + "outputs": [] | ||
| 270 | + }, | ||
| 271 | + { | ||
| 272 | + "cell_type": "code", | ||
| 273 | + "metadata": { | ||
| 274 | + "id": "CiJKucX17qb4" | ||
| 275 | + }, | ||
| 276 | + "source": [ | ||
| 277 | + "def get_model(model_class, config, tokenizer, mode):\n", | ||
| 278 | + " encoder = model_class(config=config)\n", | ||
| 279 | + " decoder_layer = nn.TransformerDecoderLayer(\n", | ||
| 280 | + " d_model=config.hidden_size, nhead=config.num_attention_heads\n", | ||
| 281 | + " )\n", | ||
| 282 | + " decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n", | ||
| 283 | + " model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n", | ||
| 284 | + " beam_size=args.beam_size, max_length=args.max_target_length,\n", | ||
| 285 | + " sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n", | ||
| 286 | + "\n", | ||
| 287 | + " assert args.load_model_path\n", | ||
| 288 | + " assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n", | ||
| 289 | + "\n", | ||
| 290 | + " model.load_state_dict(\n", | ||
| 291 | + " torch.load(\n", | ||
| 292 | + " os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n", | ||
| 293 | + " map_location=torch.device('cpu')\n", | ||
| 294 | + " ),\n", | ||
| 295 | + " strict=False\n", | ||
| 296 | + " )\n", | ||
| 297 | + " return model\n", | ||
| 298 | + "\n", | ||
| 299 | + "def get_features(examples):\n", | ||
| 300 | + " features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n", | ||
| 301 | + " all_source_ids = torch.tensor(\n", | ||
| 302 | + " [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
| 303 | + " )\n", | ||
| 304 | + " all_source_mask = torch.tensor(\n", | ||
| 305 | + " [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
| 306 | + " )\n", | ||
| 307 | + " all_patch_ids = torch.tensor(\n", | ||
| 308 | + " [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n", | ||
| 309 | + " )\n", | ||
| 310 | + " return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n", | ||
| 311 | + "\n", | ||
| 312 | + "def create_app():\n", | ||
| 313 | + " @app.route('/')\n", | ||
| 314 | + " def index():\n", | ||
| 315 | + " return jsonify(hello=\"world\")\n", | ||
| 316 | + "\n", | ||
| 317 | + " @app.route('/added', methods=['POST'])\n", | ||
| 318 | + " def added():\n", | ||
| 319 | + " if request.method == 'POST':\n", | ||
| 320 | + " payload = request.get_json()\n", | ||
| 321 | + " example = [\n", | ||
| 322 | + " Example(\n", | ||
| 323 | + " idx=payload['idx'],\n", | ||
| 324 | + " added=payload['added'],\n", | ||
| 325 | + " deleted=payload['deleted'],\n", | ||
| 326 | + " target=None\n", | ||
| 327 | + " )\n", | ||
| 328 | + " ]\n", | ||
| 329 | + " message = inference(model=args.added_model, data=get_features(example))\n", | ||
| 330 | + " return jsonify(idx=payload['idx'], message=message)\n", | ||
| 331 | + "\n", | ||
| 332 | + " @app.route('/diff', methods=['POST'])\n", | ||
| 333 | + " def diff():\n", | ||
| 334 | + " if request.method == 'POST':\n", | ||
| 335 | + " payload = request.get_json()\n", | ||
| 336 | + " example = [\n", | ||
| 337 | + " Example(\n", | ||
| 338 | + " idx=payload['idx'],\n", | ||
| 339 | + " added=payload['added'],\n", | ||
| 340 | + " deleted=payload['deleted'],\n", | ||
| 341 | + " target=None\n", | ||
| 342 | + " )\n", | ||
| 343 | + " ]\n", | ||
| 344 | + " message = inference(model=args.diff_model, data=get_features(example))\n", | ||
| 345 | + " return jsonify(idx=payload['idx'], message=message)\n", | ||
| 346 | + "\n", | ||
| 347 | + " @app.route('/tokenizer', methods=['POST'])\n", | ||
| 348 | + " def tokenizer():\n", | ||
| 349 | + " if request.method == 'POST':\n", | ||
| 350 | + " payload = request.get_json()\n", | ||
| 351 | + " tokens = args.tokenizer.tokenize(payload['code'])\n", | ||
| 352 | + " return jsonify(tokens=tokens)\n", | ||
| 353 | + "\n", | ||
| 354 | + " return app\n", | ||
| 355 | + "\n", | ||
| 356 | + "def inference(model, data):\n", | ||
| 357 | + " # Calculate bleu\n", | ||
| 358 | + " eval_sampler = SequentialSampler(data)\n", | ||
| 359 | + " eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n", | ||
| 360 | + "\n", | ||
| 361 | + " model.eval()\n", | ||
| 362 | + " p=[]\n", | ||
| 363 | + " for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n", | ||
| 364 | + " batch = tuple(t.to(args.device) for t in batch)\n", | ||
| 365 | + " source_ids, source_mask, patch_ids = batch\n", | ||
| 366 | + " with torch.no_grad():\n", | ||
| 367 | + " preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n", | ||
| 368 | + " for pred in preds:\n", | ||
| 369 | + " t = pred[0].cpu().numpy()\n", | ||
| 370 | + " t = list(t)\n", | ||
| 371 | + " if 0 in t:\n", | ||
| 372 | + " t = t[:t.index(0)]\n", | ||
| 373 | + " text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n", | ||
| 374 | + " p.append(text)\n", | ||
| 375 | + " return p" | ||
| 376 | + ], | ||
| 377 | + "execution_count": null, | ||
| 378 | + "outputs": [] | ||
| 379 | + }, | ||
| 380 | + { | ||
| 381 | + "cell_type": "markdown", | ||
| 382 | + "metadata": { | ||
| 383 | + "id": "Esf4r-Ai8cG3" | ||
| 384 | + }, | ||
| 385 | + "source": [ | ||
| 386 | + "**Set environment**" | ||
| 387 | + ] | ||
| 388 | + }, | ||
| 389 | + { | ||
| 390 | + "cell_type": "code", | ||
| 391 | + "metadata": { | ||
| 392 | + "id": "mR7gVmSoSUoy" | ||
| 393 | + }, | ||
| 394 | + "source": [ | ||
| 395 | + "import easydict \n", | ||
| 396 | + "\n", | ||
| 397 | + "args = easydict.EasyDict({\n", | ||
| 398 | + " 'load_model_path': 'weight/', \n", | ||
| 399 | + " 'model_type': 'roberta',\n", | ||
| 400 | + " 'config_name' : 'microsoft/codebert-base',\n", | ||
| 401 | + " 'tokenizer_name' : 'microsoft/codebert-base',\n", | ||
| 402 | + " 'max_source_length' : 512,\n", | ||
| 403 | + " 'max_target_length' : 128,\n", | ||
| 404 | + " 'beam_size' : 10,\n", | ||
| 405 | + " 'do_lower_case' : False,\n", | ||
| 406 | + " 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||
| 407 | + "})" | ||
| 408 | + ], | ||
| 409 | + "execution_count": null, | ||
| 410 | + "outputs": [] | ||
| 411 | + }, | ||
| 412 | + { | ||
| 413 | + "cell_type": "code", | ||
| 414 | + "metadata": { | ||
| 415 | + "id": "e8dk5RwvToOv" | ||
| 416 | + }, | ||
| 417 | + "source": [ | ||
| 418 | + "# flask_ngrok_example.py\n", | ||
| 419 | + "from flask_ngrok import run_with_ngrok\n", | ||
| 420 | + "\n", | ||
| 421 | + "app = Flask(__name__)\n", | ||
| 422 | + "run_with_ngrok(app) # Start ngrok when app is run\n", | ||
| 423 | + "\n", | ||
| 424 | + "config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n", | ||
| 425 | + "config = config_class.from_pretrained(args.config_name)\n", | ||
| 426 | + "args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n", | ||
| 427 | + "\n", | ||
| 428 | + "# build model\n", | ||
| 429 | + "args.added_model =get_model(model_class=model_class, config=config,\n", | ||
| 430 | + " tokenizer=args.tokenizer, mode='added').to(args.device)\n", | ||
| 431 | + "args.diff_model = get_model(model_class=model_class, config=config,\n", | ||
| 432 | + " tokenizer=args.tokenizer, mode='diff').to(args.device)\n", | ||
| 433 | + "\n", | ||
| 434 | + "app = create_app()\n", | ||
| 435 | + "app.run()" | ||
| 436 | + ], | ||
| 437 | + "execution_count": null, | ||
| 438 | + "outputs": [] | ||
| 439 | + }, | ||
| 440 | + { | ||
| 441 | + "cell_type": "markdown", | ||
| 442 | + "metadata": { | ||
| 443 | + "id": "DXkBcO_sU_VN" | ||
| 444 | + }, | ||
| 445 | + "source": [ | ||
| 446 | + "## Set commit configure\n", | ||
| 447 | + "Now, set commit configure on your local computer.\n", | ||
| 448 | + "```shell\n", | ||
| 449 | + "$ commit configure --endpoint http://********.ngrok.io\n", | ||
| 450 | + "```" | ||
| 451 | + ] | ||
| 452 | + } | ||
| 453 | + ] | ||
| 444 | } | 454 | } |
| ... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
-
Please register or login to post a comment