graykode

(add) javascript support in google colab

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "commit-autosuggestions.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "commit-autosuggestions.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "DZ7rFp2gzuNS"
},
"source": [
"## Start commit-autosuggestions server\n",
"Running flask app server in Google Colab for people without GPU"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d8Lyin2I3wHq"
},
"source": [
"#### Clone github repository"
]
},
{
"cell_type": "code",
"metadata": {
"id": "e_cu9igvzjcs"
},
"source": [
"!git clone https://github.com/graykode/commit-autosuggestions.git\n",
"%cd commit-autosuggestions\n",
"!pip install -r requirements.txt"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PFKn5QZr0dQx"
},
"source": [
"#### Download model weights\n",
"\n",
"Download the two weights of model from the google drive through the gdown module.\n",
"1. [Added model](https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4) : A model trained Code2NL on Python using pre-trained CodeBERT (Feng at al, 2020).\n",
"2. [Diff model](https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m) : A model retrained by initializing with the weight of model (1), adding embedding of the added and deleted parts(`patch_ids_embedding`) of the code."
]
},
{
"cell_type": "code",
"metadata": {
"id": "P9-EBpxt0Dp0"
},
"source": [
"!pip install gdown \\\n",
" && gdown \"https://drive.google.com/uc?id=1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4\" -O weight/added/pytorch_model.bin \\\n",
" && gdown \"https://drive.google.com/uc?id=1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\" -O weight/diff/pytorch_model.bin"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "org4Gqdv3iUu"
},
"source": [
"#### ngrok setting with flask\n",
"\n",
"Before starting the server, you need to configure ngrok to open this notebook to the outside. I have referred [this jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb) in detail."
]
},
{
"cell_type": "code",
"metadata": {
"id": "lZA3kuuG1Crj"
},
"source": [
"!pip install flask-ngrok"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hR78FRCMcqrZ"
},
"source": [
"Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "L_mInbOKcoc2"
},
"source": [
"# Paste your authtoken here in quotes\n",
"authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "QwCN4YFUc0M8"
},
"source": [
"Set your region\n",
"\n",
"Code | Region\n",
"--- | ---\n",
"us | United States\n",
"eu | Europe\n",
"ap | Asia/Pacific\n",
"au | Australia\n",
"sa | South America\n",
"jp | Japan\n",
"in | India"
]
},
{
"cell_type": "code",
"metadata": {
"id": "p4LSNN2xc0dQ"
},
"source": [
"# Set your region here in quotes\n",
"region = \"jp\"\n",
"\n",
"# Input and output ports for communication\n",
"local_in_port = 5000\n",
"local_out_port = 5000"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kg56PVrOdhi1"
},
"source": [
"config =\\\n",
"f\"\"\"\n",
"authtoken: {authtoken}\n",
"region: {region}\n",
"console_ui: False\n",
"tunnels:\n",
" input:\n",
" addr: {local_in_port}\n",
" proto: http \n",
" output:\n",
" addr: {local_out_port}\n",
" proto: http\n",
"\"\"\"\n",
"\n",
"with open('ngrok.conf', 'w') as f:\n",
" f.write(config)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "hrWDrw_YdjIy"
},
"source": [
"import time\n",
"from subprocess import Popen, PIPE\n",
"\n",
"# (Re)Open tunnel\n",
"ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n",
"time.sleep(3)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pJgdFr0Fdjoq",
"outputId": "3948f70b-d4f3-4ed8-a864-fe5c6df50809",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# Get tunnel addresses\n",
"try:\n",
" in_addr, out_addr = get_tunnel_adresses()\n",
" print(\"Tunnel opened\")\n",
"except Exception as e:\n",
" [print(l.decode(), end='') for l in ps.stdout.readlines()]\n",
" print(\"Something went wrong, reopen the tunnel\")"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Opening tunnel\n",
"Something went wrong, reopen the tunnel\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cEZ-O0wz74OJ"
},
"source": [
"#### Run you server!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7PRkeYTL8Y_6"
},
"source": [
"import os\n",
"import torch\n",
"import argparse\n",
"from tqdm import tqdm\n",
"import torch.nn as nn\n",
"from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n",
"from transformers import (RobertaConfig, RobertaTokenizer)\n",
"\n",
"from commit.model import Seq2Seq\n",
"from commit.utils import (Example, convert_examples_to_features)\n",
"from commit.model.diff_roberta import RobertaModel\n",
"\n",
"from flask import Flask, jsonify, request\n",
"\n",
"MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CiJKucX17qb4"
},
"source": [
"def get_model(model_class, config, tokenizer, mode):\n",
" encoder = model_class(config=config)\n",
" decoder_layer = nn.TransformerDecoderLayer(\n",
" d_model=config.hidden_size, nhead=config.num_attention_heads\n",
" )\n",
" decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n",
" model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n",
" beam_size=args.beam_size, max_length=args.max_target_length,\n",
" sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n",
"\n",
" assert args.load_model_path\n",
" assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n",
"\n",
" model.load_state_dict(\n",
" torch.load(\n",
" os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n",
" map_location=torch.device('cpu')\n",
" ),\n",
" strict=False\n",
" )\n",
" return model\n",
"\n",
"def get_features(examples):\n",
" features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n",
" all_source_ids = torch.tensor(\n",
" [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" all_source_mask = torch.tensor(\n",
" [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" all_patch_ids = torch.tensor(\n",
" [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n",
"\n",
"def create_app():\n",
" @app.route('/')\n",
" def index():\n",
" return jsonify(hello=\"world\")\n",
"\n",
" @app.route('/added', methods=['POST'])\n",
" def added():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" example = [\n",
" Example(\n",
" idx=payload['idx'],\n",
" added=payload['added'],\n",
" deleted=payload['deleted'],\n",
" target=None\n",
" )\n",
" ]\n",
" message = inference(model=args.added_model, data=get_features(example))\n",
" return jsonify(idx=payload['idx'], message=message)\n",
"\n",
" @app.route('/diff', methods=['POST'])\n",
" def diff():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" example = [\n",
" Example(\n",
" idx=payload['idx'],\n",
" added=payload['added'],\n",
" deleted=payload['deleted'],\n",
" target=None\n",
" )\n",
" ]\n",
" message = inference(model=args.diff_model, data=get_features(example))\n",
" return jsonify(idx=payload['idx'], message=message)\n",
"\n",
" @app.route('/tokenizer', methods=['POST'])\n",
" def tokenizer():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" tokens = args.tokenizer.tokenize(payload['code'])\n",
" return jsonify(tokens=tokens)\n",
"\n",
" return app\n",
"\n",
"def inference(model, data):\n",
" # Calculate bleu\n",
" eval_sampler = SequentialSampler(data)\n",
" eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n",
"\n",
" model.eval()\n",
" p=[]\n",
" for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n",
" batch = tuple(t.to(args.device) for t in batch)\n",
" source_ids, source_mask, patch_ids = batch\n",
" with torch.no_grad():\n",
" preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n",
" for pred in preds:\n",
" t = pred[0].cpu().numpy()\n",
" t = list(t)\n",
" if 0 in t:\n",
" t = t[:t.index(0)]\n",
" text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n",
" p.append(text)\n",
" return p"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Esf4r-Ai8cG3"
},
"source": [
"**Set enviroment**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mR7gVmSoSUoy"
},
"source": [
"import easydict \n",
"\n",
"args = easydict.EasyDict({\n",
" 'load_model_path': 'weight/', \n",
" 'model_type': 'roberta',\n",
" 'config_name' : 'microsoft/codebert-base',\n",
" 'tokenizer_name' : 'microsoft/codebert-base',\n",
" 'max_source_length' : 512,\n",
" 'max_target_length' : 128,\n",
" 'beam_size' : 10,\n",
" 'do_lower_case' : False,\n",
" 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"})"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "e8dk5RwvToOv"
},
"source": [
"# flask_ngrok_example.py\n",
"from flask_ngrok import run_with_ngrok\n",
"\n",
"app = Flask(__name__)\n",
"run_with_ngrok(app) # Start ngrok when app is run\n",
"\n",
"config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n",
"config = config_class.from_pretrained(args.config_name)\n",
"args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n",
"\n",
"# budild model\n",
"args.added_model =get_model(model_class=model_class, config=config,\n",
" tokenizer=args.tokenizer, mode='added').to(args.device)\n",
"args.diff_model = get_model(model_class=model_class, config=config,\n",
" tokenizer=args.tokenizer, mode='diff').to(args.device)\n",
"\n",
"app = create_app()\n",
"app.run()"
],
"execution_count": null,
"outputs": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "DZ7rFp2gzuNS"
},
"source": [
"## Start commit-autosuggestions server\n",
"Running flask app server in Google Colab for people without GPU"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d8Lyin2I3wHq"
},
"source": [
"#### Clone github repository"
]
},
{
"cell_type": "code",
"metadata": {
"id": "e_cu9igvzjcs"
},
"source": [
"!git clone https://github.com/graykode/commit-autosuggestions.git\n",
"%cd commit-autosuggestions\n",
"!pip install -r requirements.txt"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PFKn5QZr0dQx"
},
"source": [
"#### Download model weights\n",
"\n",
"Download the two weights of model from the google drive through the gdown module.\n",
"1. Added model : A model trained Code2NL on Python using pre-trained CodeBERT (Feng at al, 2020).\n",
"2. Diff model : A model retrained by initializing with the weight of model (1), adding embedding of the added and deleted parts(`patch_ids_embedding`) of the code.\n",
"\n",
"Download pre-trained weight\n",
"\n",
"Language | Added | Diff\n",
"--- | --- | ---\n",
"python | 1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4 | 1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\n",
"javascript | 1-F68ymKxZ-htCzQ8_Y9iHexs2SJmP5Gc | 1-39rmu-3clwebNURMQGMt-oM4HsAkbsf"
]
},
{
"cell_type": "code",
"metadata": {
"id": "P9-EBpxt0Dp0"
},
"source": [
"ADD_MODEL='1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4'\n",
"DIFF_MODEL='1--gcVVix92_Fp75A-mWH0pJS0ahlni5m'\n",
"\n",
"!pip install gdown \\\n",
" && gdown \"https://drive.google.com/uc?id=$ADD_MODEL\" -O weight/added/pytorch_model.bin \\\n",
" && gdown \"https://drive.google.com/uc?id=$DIFF_MODEL\" -O weight/diff/pytorch_model.bin"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "org4Gqdv3iUu"
},
"source": [
"#### ngrok setting with flask\n",
"\n",
"Before starting the server, you need to configure ngrok to open this notebook to the outside. I have referred [this jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb) in detail."
]
},
{
"cell_type": "code",
"metadata": {
"id": "lZA3kuuG1Crj"
},
"source": [
"!pip install flask-ngrok"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hR78FRCMcqrZ"
},
"source": [
"Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "L_mInbOKcoc2"
},
"source": [
"# Paste your authtoken here in quotes\n",
"authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "QwCN4YFUc0M8"
},
"source": [
"Set your region\n",
"\n",
"Code | Region\n",
"--- | ---\n",
"us | United States\n",
"eu | Europe\n",
"ap | Asia/Pacific\n",
"au | Australia\n",
"sa | South America\n",
"jp | Japan\n",
"in | India"
]
},
{
"cell_type": "code",
"metadata": {
"id": "p4LSNN2xc0dQ"
},
"source": [
"# Set your region here in quotes\n",
"region = \"jp\"\n",
"\n",
"# Input and output ports for communication\n",
"local_in_port = 5000\n",
"local_out_port = 5000"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kg56PVrOdhi1"
},
"source": [
"config =\\\n",
"f\"\"\"\n",
"authtoken: {authtoken}\n",
"region: {region}\n",
"console_ui: False\n",
"tunnels:\n",
" input:\n",
" addr: {local_in_port}\n",
" proto: http \n",
" output:\n",
" addr: {local_out_port}\n",
" proto: http\n",
"\"\"\"\n",
"\n",
"with open('ngrok.conf', 'w') as f:\n",
" f.write(config)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "hrWDrw_YdjIy"
},
"source": [
"import time\n",
"from subprocess import Popen, PIPE\n",
"\n",
"# (Re)Open tunnel\n",
"ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n",
"time.sleep(3)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pJgdFr0Fdjoq",
"outputId": "3948f70b-d4f3-4ed8-a864-fe5c6df50809",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# Get tunnel addresses\n",
"try:\n",
" in_addr, out_addr = get_tunnel_adresses()\n",
" print(\"Tunnel opened\")\n",
"except Exception as e:\n",
" [print(l.decode(), end='') for l in ps.stdout.readlines()]\n",
" print(\"Something went wrong, reopen the tunnel\")"
],
"execution_count": null,
"outputs": [
{
"cell_type": "markdown",
"metadata": {
"id": "DXkBcO_sU_VN"
},
"source": [
"## Set commit configure\n",
"Now, set commit configure on your local computer.\n",
"```shell\n",
"$ commit configure --endpoint http://********.ngrok.io\n",
"```"
]
"output_type": "stream",
"text": [
"Opening tunnel\n",
"Something went wrong, reopen the tunnel\n"
],
"name": "stdout"
}
]
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cEZ-O0wz74OJ"
},
"source": [
"#### Run you server!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7PRkeYTL8Y_6"
},
"source": [
"import os\n",
"import torch\n",
"import argparse\n",
"from tqdm import tqdm\n",
"import torch.nn as nn\n",
"from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n",
"from transformers import (RobertaConfig, RobertaTokenizer)\n",
"\n",
"from commit.model import Seq2Seq\n",
"from commit.utils import (Example, convert_examples_to_features)\n",
"from commit.model.diff_roberta import RobertaModel\n",
"\n",
"from flask import Flask, jsonify, request\n",
"\n",
"MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CiJKucX17qb4"
},
"source": [
"def get_model(model_class, config, tokenizer, mode):\n",
" encoder = model_class(config=config)\n",
" decoder_layer = nn.TransformerDecoderLayer(\n",
" d_model=config.hidden_size, nhead=config.num_attention_heads\n",
" )\n",
" decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n",
" model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n",
" beam_size=args.beam_size, max_length=args.max_target_length,\n",
" sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n",
"\n",
" assert args.load_model_path\n",
" assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n",
"\n",
" model.load_state_dict(\n",
" torch.load(\n",
" os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n",
" map_location=torch.device('cpu')\n",
" ),\n",
" strict=False\n",
" )\n",
" return model\n",
"\n",
"def get_features(examples):\n",
" features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n",
" all_source_ids = torch.tensor(\n",
" [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" all_source_mask = torch.tensor(\n",
" [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" all_patch_ids = torch.tensor(\n",
" [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n",
"\n",
"def create_app():\n",
" @app.route('/')\n",
" def index():\n",
" return jsonify(hello=\"world\")\n",
"\n",
" @app.route('/added', methods=['POST'])\n",
" def added():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" example = [\n",
" Example(\n",
" idx=payload['idx'],\n",
" added=payload['added'],\n",
" deleted=payload['deleted'],\n",
" target=None\n",
" )\n",
" ]\n",
" message = inference(model=args.added_model, data=get_features(example))\n",
" return jsonify(idx=payload['idx'], message=message)\n",
"\n",
" @app.route('/diff', methods=['POST'])\n",
" def diff():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" example = [\n",
" Example(\n",
" idx=payload['idx'],\n",
" added=payload['added'],\n",
" deleted=payload['deleted'],\n",
" target=None\n",
" )\n",
" ]\n",
" message = inference(model=args.diff_model, data=get_features(example))\n",
" return jsonify(idx=payload['idx'], message=message)\n",
"\n",
" @app.route('/tokenizer', methods=['POST'])\n",
" def tokenizer():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" tokens = args.tokenizer.tokenize(payload['code'])\n",
" return jsonify(tokens=tokens)\n",
"\n",
" return app\n",
"\n",
"def inference(model, data):\n",
" # Calculate bleu\n",
" eval_sampler = SequentialSampler(data)\n",
" eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n",
"\n",
" model.eval()\n",
" p=[]\n",
" for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n",
" batch = tuple(t.to(args.device) for t in batch)\n",
" source_ids, source_mask, patch_ids = batch\n",
" with torch.no_grad():\n",
" preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n",
" for pred in preds:\n",
" t = pred[0].cpu().numpy()\n",
" t = list(t)\n",
" if 0 in t:\n",
" t = t[:t.index(0)]\n",
" text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n",
" p.append(text)\n",
" return p"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Esf4r-Ai8cG3"
},
"source": [
"**Set enviroment**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mR7gVmSoSUoy"
},
"source": [
"import easydict \n",
"\n",
"args = easydict.EasyDict({\n",
" 'load_model_path': 'weight/', \n",
" 'model_type': 'roberta',\n",
" 'config_name' : 'microsoft/codebert-base',\n",
" 'tokenizer_name' : 'microsoft/codebert-base',\n",
" 'max_source_length' : 512,\n",
" 'max_target_length' : 128,\n",
" 'beam_size' : 10,\n",
" 'do_lower_case' : False,\n",
" 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"})"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "e8dk5RwvToOv"
},
"source": [
"# flask_ngrok_example.py\n",
"from flask_ngrok import run_with_ngrok\n",
"\n",
"app = Flask(__name__)\n",
"run_with_ngrok(app) # Start ngrok when app is run\n",
"\n",
"config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n",
"config = config_class.from_pretrained(args.config_name)\n",
"args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n",
"\n",
"# budild model\n",
"args.added_model =get_model(model_class=model_class, config=config,\n",
" tokenizer=args.tokenizer, mode='added').to(args.device)\n",
"args.diff_model = get_model(model_class=model_class, config=config,\n",
" tokenizer=args.tokenizer, mode='diff').to(args.device)\n",
"\n",
"app = create_app()\n",
"app.run()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "DXkBcO_sU_VN"
},
"source": [
"## Set commit configure\n",
"Now, set commit configure on your local computer.\n",
"```shell\n",
"$ commit configure --endpoint http://********.ngrok.io\n",
"```"
]
}
]
}
\ No newline at end of file
......