About

Hi there, in this post you'll learn how to finetune a RoBERTa-based model that's been trained on code data to automatically generate comments for code!

We will be focusing on the Java programming language, but you can apply the same techniques from this post to any programming language that interests you. Additionally, you'll see how to incorporate this code commenter into a VSCode extension so that you can generate comments for code snippets you highlight:

(Insert GIF of tool working)

As always, we'll start with a bit of background on the data and model we are using, but feel free to skip ahead if you want to get straight to the awesomeness ;). Alright, let's GO!

Background

Data

We will be using the awesome CodeSearchNet Challenge dataset, which contains millions of pairs of methods and their docstrings for a large variety of programming languages. The dataset was originally constructed for evaluating how well different approaches perform at searching for code. However, we can easily repurpose it for our task, and lucky for us, the authors did an awesome job collecting, documenting, and cleaning the data.

We'll be performing a bit more cleaning and formatting of the data, as well as adding some more examples. These extra examples won't be method/docstring pairs, but code snippet/inline comment pairs. This allows our model to generate comments for arbitrary code snippets that a developer may want to document, instead of just generating the docstring of a method.

CodeBERT

The pretrained model we will be finetuning comes from the awesome paper out of Microsoft Research aptly named CodeBERT: A Pre-Trained Model for Programming and Natural Languages. This model also used the CodeSearchNet Challenge dataset, but instead of using it to generate comments, it was used to teach a RoBERTa-based model to represent code and natural language in a useful way. Teaching these large language models to represent text in a useful way is common practice now, since these representations have been shown to be helpful when finetuning on other tasks. The CodeBERT paper showed its representations are helpful by finetuning on the programming tasks of code search and comment generation, exactly what we will be doing! The difference between their comment generation task and ours is that we will do a bit more preprocessing, and our model will be able to generate inline comments for code snippets, not just method-level comments.

So, how does CodeBERT learn these representations? It combines two training objectives that have been shown to be useful for natural language: the Masked Language Modeling (MLM) objective, which is from the original BERT paper, and the Replaced Token Detection (RTD) objective, which is from the ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators paper. The MLM objective randomly masks out parts of the text we feed into the model and asks the model to predict the masked-out pieces. The RTD objective replaces random tokens in the text and asks the model to determine which tokens have been replaced. To make this harder for the model, the replacement tokens are meant to be plausible alternatives rather than just random words. The CodeBERT model actually used an n-gram based model to generate these alternatives, whereas the ELECTRA paper used a small BERT-based model.

ELECTRA Pretraining Objective (From ELECTRA Paper)

Instead of applying these training objectives only to natural language, CodeBERT applied them to code and docstrings. This allowed the CodeBERT model to learn a useful representation of code that can be used for other tasks.
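To get a feel for the MLM objective, here's a minimal sketch that masks a token in a Java snippet and asks a fill-mask model to recover it. I'm assuming the microsoft/codebert-base-mlm checkpoint here (the MLM-only variant released alongside CodeBERT); any RoBERTa-style fill-mask checkpoint would demonstrate the same idea.

from transformers import pipeline

# Mask a single token and let the model guess it back -- MLM in a nutshell
fill_mask = pipeline('fill-mask', model='microsoft/codebert-base-mlm')
preds = fill_mask('public int add(int a, int b) { return a <mask> b; }')
for pred in preds[:3]:
    print(round(pred['score'], 3), pred['sequence'])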

Alright, with that quick background knowledge down, let's get into actually finetuning our model!

! nvidia-smi
Thu Jan 14 20:43:12 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.27.04    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8    11W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Data

First we'll install the necessary packages and download our data!

# Download and install the necessary dependencies
! pip install -q torch==1.4.0 -f https://download.pytorch.org/whl/cu101/torch_stable.html
! pip install -q transformers==3.5.0 fast-trees

! git clone -q https://github.com/microsoft/CodeXGLUE.git

# Download the CodeSearchNet Challenge dataset for the Java programming language
! wget -q https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip
! unzip -qq java.zip
     |████████████████████████████████| 753.4MB 21kB/s 
ERROR: torchvision 0.8.1+cu101 has requirement torch==1.7.0, but you'll have torch 1.4.0 which is incompatible.
     |████████████████████████████████| 1.3MB 12.6MB/s 
     |████████████████████████████████| 890kB 51.7MB/s 
     |████████████████████████████████| 2.9MB 50.3MB/s 
     |████████████████████████████████| 1.1MB 61.3MB/s 
     |████████████████████████████████| 112kB 64.8MB/s 
     |████████████████████████████████| 163kB 60.4MB/s 
     |████████████████████████████████| 71kB 11.8MB/s 
  Building wheel for sacremoses (setup.py) ... done
  Building wheel for tree-sitter (setup.py) ... done
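Before loading everything into dataframes, it's worth peeking at one raw record to see what fields CodeSearchNet gives us. The shard name below is an assumption (it's the first training file the unzip above should produce), so adjust it if your layout differs:

import gzip, json

# Peek at a single raw record from the first training shard
with gzip.open('java/final/jsonl/train/java_train_0.jsonl.gz') as f:
    record = json.loads(f.readline())

print(sorted(record.keys()))
print(record['docstring'][:100])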

Next, let's read in our data. Since these models take a long time to train, we will only select a subset of the data.

import pandas as pd

from pathlib import Path
from typing import List, Optional

# Code from CodeSearchNetChallenge: https://github.com/github/CodeSearchNet/blob/master/notebooks/ExploreData.ipynb
def jsonl_list_to_dataframe(file_list, columns=['code', 'docstring']):
    """Load a list of jsonl.gz files into a pandas DataFrame."""
    return pd.concat([pd.read_json(f,
                                   orient='records', 
                                   compression='gzip',
                                   lines=True)[columns] 
                      for f in file_list], sort=False)

def get_dfs(path: Path) -> List[pd.DataFrame]:
    """Grabs the different data splits and converts them into dataframes"""
    dfs = []
    for split in ["train", "valid", "test"]:
        files = sorted((path/split).glob("**/*.gz"))
        df = jsonl_list_to_dataframe(files).rename(columns = {'code': 'mthd', 'docstring': 'cmt'})
        dfs.append(df)
        
    return dfs

path = Path('.')
df_trn, df_val, df_tst = get_dfs(path/"java/final/jsonl")
sample = 0.01
df_trn = df_trn.sample(frac = sample)
df_val = df_val.sample(frac = sample)
df_tst = df_tst.sample(frac = sample)

len(df_trn), len(df_val), len(df_tst)
(4545, 153, 269)

Let's see how the data looks. As shown below, the data is in a good format, with one column containing the methods (the input to the model) and the other containing the comments (the output of the model).

df_trn.head()
mthd cmt
5360 @Override\n public GetLexiconResult getLexi... <p>\nReturns the content of the specified pron...
9365 public static void checkJavaInternalAccess(ILo... Prints warning to given {@link ILogger} if Haz...
10145 private IAtom createAtom(Element element) {\n ... Create a new atom for the provided symbol. The...
9008 public void marshall(Scte20PlusEmbeddedDestina... Marshall the given parameter object.
24498 @Override\n public void prefetchToken(final F... /*\nGets hadoop tokens for a user to run mapre...

Data Cleaning

Now that we have the data, let's clean it! First, we'll remove any non-ASCII characters to simplify the problem so that the model only has to think about generating English comments.

# From https://stackoverflow.com/a/27084708/5768407
def is_ascii(s):
    '''
    Determines if the given string contains only ascii characters

    :param s: the string to check
    :returns: whether or not the given string contains only ascii characters
    '''
    try:
        s.encode(encoding='utf-8').decode('ascii')
    except UnicodeDecodeError:
        return False
    else:
        return True

df_trn = df_trn[df_trn['mthd'].apply(lambda x: is_ascii(x))]
df_val = df_val[df_val['mthd'].apply(lambda x: is_ascii(x))]
df_tst = df_tst[df_tst['mthd'].apply(lambda x: is_ascii(x))]

df_trn = df_trn[df_trn['cmt'].apply(lambda x: is_ascii(x))]
df_val = df_val[df_val['cmt'].apply(lambda x: is_ascii(x))]
df_tst = df_tst[df_tst['cmt'].apply(lambda x: is_ascii(x))]

len(df_trn), len(df_val), len(df_tst)
(4402, 141, 264)

Next, we'll remove any outdated comments by checking whether the JavaDoc's parameter list differs from the method's parameter list. This also removes pairs where the docstring doesn't actually document the parameters, which probably means the pairs are poor quality (you should always properly document your code :) ).

import re

from fast_trees.core import FastParser

parser = FastParser('java')

def get_cmt_params(cmt: str) -> List[str]:
    '''
    Grabs the parameter identifier names from a JavaDoc comment

    :param cmt: the comment to extract the parameter identifier names from
    :returns: an array of the parameter identifier names found in the given comment
    '''
    params = re.findall(r'@param\s+\w+', cmt)
    param_names = []
    for param in params:
        param_names.append(param.split()[1])
    
    return param_names

def is_outdated(mthd: str, cmt: str, parser: FastParser) -> bool:
    '''
    Determines if a given method and comment are outdated by checking
    if the method's parameter identifier names match the comment's

    :param mthd: the method to compare against its corresponding comment
    :param cmt: the comment to compare against its corresponding method
    :param parser: parser for easily getting the parameter identifier names from a given method
    :returns: whether or not the given comment is outdated compared to its corresponding method
    '''
    try:
        mthd_params = parser.get_params(mthd)
    except Exception:
        return False
    
    cmt_params = get_cmt_params(cmt)

    return mthd_params != cmt_params

df_trn = df_trn[
    ~df_trn.apply(
        lambda x: is_outdated(x.mthd, x.cmt, parser), axis = 1
    )
]
df_val = df_val[
    ~df_val.apply(
        lambda x: is_outdated(x.mthd, x.cmt, parser), axis = 1
    )
]
df_tst = df_tst[
    ~df_tst.apply(
        lambda x: is_outdated(x.mthd, x.cmt, parser), axis = 1
    )
]

len(df_trn), len(df_val), len(df_tst)
Downloading repo https://github.com/tree-sitter/tree-sitter-java to /usr/local/lib/python3.6/dist-packages/fast_trees/tree-sitter-java.
(4402, 141, 264)
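To make the comment side of that check concrete, here's what get_cmt_params pulls out of a small made-up JavaDoc fragment:

# Toy input; the regex grabs the identifier after each @param
get_cmt_params('''
@param mthd the method to compare against its comment
@param cmt the comment to compare against its method
''')
# -> ['mthd', 'cmt']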

Now we'll add in the additional pairs of code snippets/inline comments.

P.S. One thing to note about adding these pairs is that the inline comments will appear twice in the dataset: first inside the method they came from, and second as the target for the extracted code snippet. This is only a problem for the training set, since it allows the model to cheat by simply remembering the inline comment from the example method it came from. However, in my testing, I found this not to be an issue, and the model seems to work well despite it. Just thought ya should know :).

from tqdm.auto import tqdm

def get_inline_pairs(mthd):
    '''
    Get all pairs of inline comments and corresponding code snippets

    :param mthd: the method to retrieve the pairs of comments and corresponding
    code snippets from
    :returns: all pairs of comments and corresponding code snippets
    '''
    pairs = [[]]

    comment = False
    bracket = False
    indent_lvl = -1
    lines = mthd.split("\n")
    for line in lines:
        if "//" in line and not bracket and not "://" in line:
            pairs[-1].append(line)
            if '\t' in line:
                indent_lvl = line.count('\t')
            else:
                indent_lvl = line.split("//")[0].count(' ')
            comment = True
            bracket = False
        elif comment:
            if '{' in line and not bracket:
                bracket = True
                pairs[-1].append(line)
            elif '}' in line:
                line_indent = -1
                if '\t' in line:
                    line_indent = line.count('\t')
                else:
                    line_indent = line.split("//")[0].count(' ')
                if indent_lvl == line_indent:
                    pairs[-1].append(line)
                if not bracket:
                    pairs.append([])
                    comment = False
                    bracket = False
            elif line.isspace() or line == '' and not bracket:
                pairs.append([])
                comment = False
            else:
                pairs[-1].append(line)
    
    # Convert pairs into proper format of (code snippet, inline comment) dataframe
    code_snippets   = []
    comments        = []
    for pair in pairs:
        if pair and len(pair) < 5:
            code    = []
            comment = []
            for line in pair:
                if "TODO" in line: break
                if "//" in line:
                    comment.append(line.replace('//', ''))
                else:
                    code.append(line)
            if len(code) > 1 and len(comment) > 0:
                code_snippets.append('\n'.join(code))
                comments.append('\n'.join(comment))

    pairs = pd.DataFrame(zip(code_snippets, comments), columns = ["mthd", "cmt"])
    return pairs


def add_inline(df: pd.DataFrame) -> pd.DataFrame:
    '''
    Helper function to go through all methods in a given dataframe and add all
    pairs of inline comments and corresponding code snippets

    :param df: the dataframe to retrieve and add all pairs of inline comments
    and corresponding code snippets to
    :returns: a new dataframe with the newly added pairs of inline comments and
    corresponding code snippets
    '''
    new_df = df[df['mthd'].str.contains("//")]
    all_pairs = []
    for mthd in tqdm(new_df.mthd.values):
        pairs = get_inline_pairs(mthd)
        all_pairs.append(pairs)

    df_pairs = pd.concat([pairs for pairs in all_pairs])
    return pd.concat([df, df_pairs])
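Before applying this to the full dataframes, a quick sanity check on a made-up method shows what the heuristic extracts:

# A toy method with one inline comment followed by a couple of code lines
toy_mthd = '\n'.join([
    'public void foo() {',
    '    // increment the counter and fold it into the total',
    '    count = count + 1;',
    '    total = total + count;',
    '',
    '    return;',
    '}',
])
get_inline_pairs(toy_mthd)
# -> a one-row dataframe pairing the two code lines with the comment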

df_trn = add_inline(df_trn)
df_val = add_inline(df_val)
df_tst = add_inline(df_tst)

len(df_trn), len(df_val), len(df_tst)


(4584, 150, 271)

We'll also remove pairs where the code is shorter than the comment. This is because I found that, in these cases, the comments contain a bunch of extra information that the model won't have access to, such as how the method is used by other methods in the software system.

df_trn = df_trn[df_trn.apply(lambda row: len(row.mthd) > len(row.cmt), axis = 1)]
df_val = df_val[df_val.apply(lambda row: len(row.mthd) > len(row.cmt), axis = 1)]
df_tst = df_tst[df_tst.apply(lambda row: len(row.mthd) > len(row.cmt), axis = 1)]

len(df_trn), len(df_val), len(df_tst)
(3713, 111, 228)

Next, we'll remove any examples whose comments contain the HTML <code> tag, since these also tend to contain extra information that the model doesn't have a good hope of generating.

def has_code(cmt: str) -> bool:
    '''
    Determine if the given comment contains the HTML <code> tag

    :param cmt: the comment to check for the HTML <code> tag
    :returns: whether or not the given comment contains the HTML <code> tag
    '''
    return '<code>' in cmt

df_trn = df_trn[~df_trn['cmt'].apply(lambda x: has_code(x))]
df_val = df_val[~df_val['cmt'].apply(lambda x: has_code(x))]
df_tst = df_tst[~df_tst['cmt'].apply(lambda x: has_code(x))]

len(df_trn), len(df_val), len(df_tst)
(3580, 104, 221)

Lastly, we're gonna remove the JavaDoc parts of the comments other than the description, since that is really all we care about. The other pieces of information can usually be autogenerated or may require external knowledge to document.

def remove_jdocs(df: pd.DataFrame) -> pd.DataFrame:
    '''
    Remove the JavaDocs leaving only the description of the comment

    :param df: the pandas dataframe to remove the JavaDocs from
    :returns: a new pandas dataframe with the JavaDocs removed
    '''
    methods = []
    comments = []
    for i, row in tqdm(list(df.iterrows())):
        comment = row["cmt"]
        # Remove {} text in comments from https://stackoverflow.com/questions/14596884/remove-text-between-and-in-python/14598135
        comment = re.sub(r"([\{\[]).*?([\)\}])", '', comment)

        cleaned = []
        for line in comment.split('\n'):
            if "@" in line: break
            cleaned.append(line)
        comments.append('\n'.join(cleaned))
        methods.append(row["mthd"])
    new_df = pd.DataFrame(zip(methods, comments), columns = ["mthd", "cmt"])

    return new_df

df_trn = remove_jdocs(df_trn)
df_val = remove_jdocs(df_val)
df_tst = remove_jdocs(df_tst)


Almost there! In this step, we'll remove any HTML tags from the comments so the model doesn't have to also learn HTML. Bless those that do...

def clean_html(cmt: str) -> str:
    '''
    Remove any HTML tags from a given comment

    :param cmt: the comment to remove any HTML tags from
    :returns: the comment with any HTML tags removed
    '''
    result = re.sub(r"<.?span[^>]*>|<.?code[^>]*>|<.?p[^>]*>|<.?hr[^>]*>|<.?h[1-3][^>]*>|<.?a[^>]*>|<.?b[^>]*>|<.?blockquote[^>]*>|<.?del[^>]*>|<.?dd[^>]*>|<.?dl[^>]*>|<.?dt[^>]*>|<.?em[^>]*>|<.?i[^>]*>|<.?img[^>]*>|<.?kbd[^>]*>|<.?li[^>]*>|<.?ol[^>]*>|<.?pre[^>]*>|<.?s[^>]*>|<.?sup[^>]*>|<.?sub[^>]*>|<.?strong[^>]*>|<.?strike[^>]*>|<.?ul[^>]*>|<.?br[^>]*>", "", cmt)
    return result

df_trn.cmt = df_trn.cmt.apply(clean_html)
df_val.cmt = df_val.cmt.apply(clean_html)
df_tst.cmt = df_tst.cmt.apply(clean_html)
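A quick sanity check of the regex on a made-up comment:

clean_html('<p>Returns the <b>name</b> of this <code>Atom</code>.</p>')
# -> 'Returns the name of this Atom.'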

FINALLY!! We'll make everything lower case, remove extra whitespace, remove empty comments, and remove duplicates.

df_trn = df_trn.applymap(lambda x: ' '.join(x.split()).lower())
df_val = df_val.applymap(lambda x: ' '.join(x.split()).lower())
df_tst = df_tst.applymap(lambda x: ' '.join(x.split()).lower())

df_trn = df_trn[~(df_trn['cmt'] == '')]
df_val = df_val[~(df_val['cmt'] == '')]
df_tst = df_tst[~(df_tst['cmt'] == '')]

df_trn = df_trn[~df_trn['cmt'].duplicated()]
df_val = df_val[~df_val['cmt'].duplicated()]
df_tst = df_tst[~df_tst['cmt'].duplicated()]

len(df_trn), len(df_val), len(df_tst)
(3094, 94, 205)

Now let's see what the data looks like.

df_trn.head()
mthd cmt
0 public static void checkjavainternalaccess(ilo... prints warning to given if hazelcast is not pr...
1 public void marshall(scte20plusembeddeddestina... marshall the given parameter object.
2 @override public void prefetchtoken(final file... /* gets hadoop tokens for a user to run mapred...
3 @override public <y> singularattribute<x, y> g... /* (non-javadoc)
4 public void sync(boolean syncallsegments) { co... forces a disk flush on the commit log files th...

Data Exploring

As good Data Scientists, we will also explore our data to uncover any secrets. Data can be sneaky like that :).

import numpy as np

from collections import Counter
from statistics import mean, median, stdev
from transformers import AutoTokenizer

def get_counter(df: pd.DataFrame, tokenizer: AutoTokenizer, col: str) -> Counter:
    '''
    Get the counts for each token in a given pandas dataframe column

    :param df: the pandas dataframe to get the counts of tokens from
    :param tokenizer: the tokenizer to use for tokenizing the rows in the pandas
    dataframe
    :param col: the column to grab rows from when tokenizing
    :returns: the counts of each token in the given pandas dataframe
    column
    '''
    toks = []
    for i, row in df.iterrows():
        toks.extend(tokenizer.tokenize(row[col]))

    return Counter(toks)

tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
mthd_cnt = get_counter(df_trn, tokenizer, 'mthd')
cmt_cnt = get_counter(df_trn, tokenizer, 'cmt')
mthd_lens = df_trn.mthd.apply(lambda x: len(tokenizer.tokenize(x))).values
cmt_lens = df_trn.cmt.apply(lambda x: len(tokenizer.tokenize(x))).values
max_mthd_len = int(np.quantile(mthd_lens, 0.95))
max_cmt_len = int(np.quantile(cmt_lens, 0.95))




import matplotlib.pyplot as plt

def plot_counts(counts: Counter, top_k: Optional[int] = 30):
    '''
    Plot a bar chart of the most common tokens

    :param counts: the counts of each token
    :param top_k: the number of tokens to display in the plot
    '''
    labels, values = zip(*counts.most_common()[:top_k])

    indexes = np.arange(len(labels))
    width = 1
    plt.figure(num=None, figsize=(22, 4), dpi=60, facecolor='w', edgecolor='k')
    plt.bar(indexes, values, width)
    plt.xticks(indexes + width * 0.5, labels)
    plt.show()

Let's look at the most common tokens in our methods and comments.

plot_counts(mthd_cnt, top_k = 30)
plot_counts(cmt_cnt, top_k = 30)

def plot_hist(lens: List[int], n_bins: Optional[int] = 50):
    '''
    Plot a histogram of the given number of tokens in a column 

    :param lens: the number of tokens in a column
    :param n_bins: the number of bins to sort the number of tokens into
    '''
    n, bins, patches = plt.hist(lens, n_bins, facecolor='blue', alpha=0.9)
    plt.show()

Now, let's look at the distribution of method and comment lengths.

print(mean(mthd_lens), median(mthd_lens), stdev(mthd_lens))
plot_hist(mthd_lens)
print(mean(cmt_lens), median(cmt_lens), stdev(cmt_lens))
plot_hist(cmt_lens)
177 102.0 283.76574846164925
17 12.0 19.77371993328519

Using this new information on the length distributions, we can remove outliers by filtering out methods and comments whose lengths fall outside the 95th percentile (chosen for completely arbitrary reasons)!

def filter_len(
    row: pd.Series, tokenizer: AutoTokenizer, mthd_len: int, cmt_len: int
    ) -> bool:
    '''
    Determine if a given pandas dataframe row has a method or comment with
    more tokens than the max length

    :param row: the row to check for a method or comment that is too long
    :param tokenizer: the tokenizer to tokenize a method or comment
    :param mthd_len: the max number of tokens a method can have
    :param cmt_len: the max number of tokens a comment can have
    :returns: whether or not the given row has a method or comment with
    more tokens than the max length
    '''
    return len(tokenizer.tokenize(row.mthd)) < mthd_len and len(tokenizer.tokenize(row.cmt)) < cmt_len

df_trn = df_trn[df_trn.apply(
    lambda row: filter_len(
        row, tokenizer, max_mthd_len,
        max_cmt_len
    ), axis = 1
)]
df_val = df_val[df_val.apply(
    lambda row: filter_len(
        row, tokenizer, max_mthd_len,
        max_cmt_len
    ), axis = 1
)]
df_tst = df_tst[df_tst.apply(
    lambda row: filter_len(
        row, tokenizer, max_mthd_len,
        max_cmt_len
    ), axis = 1
)]

len(df_trn), len(df_val), len(df_tst)
(2809, 88, 193)
max_mthd_len, max_cmt_len
(559, 48)

We could do a lot more exploring of our data; the above was the bare minimum. As an exercise, I suggest you explore the data on your own using whatever means necessary!
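For instance, a two-liner like the following (just one idea of many) shows which words our cleaned comments most often start with, which quickly surfaces commenting conventions in the data:

# Count the first word of every training comment
first_words = df_trn.cmt.apply(lambda c: c.split()[0] if c.split() else '')
print(first_words.value_counts().head(10))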

Training

Now that we have our data processed and in a format we like, let's go ahead and start training! To accomplish this, we will be using code from the awesome CodeXGLUE repository. This repository is similar to the GLUE benchmark in NLP: a ton of awesome code-related benchmarks standardized and put into one place for the community to use! They have a ton of interesting tasks, and I highly suggest looking through their repo if you are interested in other code-related work.

cd ./CodeXGLUE/Code-Text/code-to-text/code
/content/CodeXGLUE/Code-Text/code-to-text/code

Okay, I lied, sorry :(. One last processing step is required of our data: outputting it into the structure that the awesome CodeXGLUE Code-Text benchmark expects.

import json

df_trn['code_tokens'] = df_trn.mthd.apply(lambda x: x.split())
df_trn['docstring_tokens'] = df_trn.cmt.apply(lambda x: x.split())
with open('java/train.jsonl','w') as f:
    for _, row in df_trn.iterrows():
        f.write(json.dumps(row.to_dict()) + '\n')

df_val['code_tokens'] = df_val.mthd.apply(lambda x: x.split())
df_val['docstring_tokens'] = df_val.cmt.apply(lambda x: x.split())
with open('java/valid.jsonl','w') as f:
    for _, row in df_val.iterrows():
        f.write(json.dumps(row.to_dict()) + '\n')

df_tst['code_tokens'] = df_tst.mthd.apply(lambda x: x.split())
df_tst['docstring_tokens'] = df_tst.cmt.apply(lambda x: x.split())
with open('java/test.jsonl','w') as f:
    for _, row in df_tst.iterrows():
        f.write(json.dumps(row.to_dict()) + '\n')
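Since the benchmark's run.py builds its examples from the code_tokens and docstring_tokens fields, a quick read-back of the first record is a cheap sanity check that we wrote what it expects:

# Check that the written records carry the fields run.py reads
with open('java/train.jsonl') as f:
    print(sorted(json.loads(f.readline()).keys()))
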
lang = 'java' # programming language
lr = 5e-5
batch_size = 8 # change depending on the GPU Colab gives you
beam_size = 10
source_length = 256
target_length = max_cmt_len
data_dir = '.'
output_dir = f'model/{lang}'
train_file = f'{data_dir}/{lang}/train.jsonl'
dev_file = f'{data_dir}/{lang}/valid.jsonl'
epochs = 10 
pretrained_model = 'microsoft/codebert-base'

! python run.py \
    --do_train \
    --do_eval \
    --do_lower_case \
    --model_type roberta \
    --model_name_or_path {pretrained_model} \
    --train_filename {train_file} \
    --dev_filename {dev_file} \
    --output_dir {output_dir} \
    --max_source_length {source_length} \
    --max_target_length {target_length} \
    --beam_size {beam_size} \
    --train_batch_size {batch_size} \
    --eval_batch_size {batch_size} \
    --learning_rate {lr} \
    --num_train_epochs {epochs}
2021-01-14 20:49:04.427229: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
01/14/2021 20:49:06 - INFO - __main__ -   Namespace(adam_epsilon=1e-08, beam_size=10, config_name='', dev_filename='./java/valid.jsonl', do_eval=True, do_lower_case=True, do_test=False, do_train=True, eval_batch_size=8, eval_steps=-1, gradient_accumulation_steps=1, learning_rate=5e-05, load_model_path=None, local_rank=-1, max_grad_norm=1.0, max_source_length=256, max_steps=-1, max_target_length=48, model_name_or_path='microsoft/codebert-base', model_type='roberta', no_cuda=False, num_train_epochs=10, output_dir='model/java', seed=42, test_filename=None, tokenizer_name='', train_batch_size=8, train_filename='./java/train.jsonl', train_steps=-1, warmup_steps=0, weight_decay=0.0)
01/14/2021 20:49:06 - WARNING - __main__ -   Process rank: -1, device: cuda, n_gpu: 1, distributed training: False
01/14/2021 20:49:06 - INFO - filelock -   Lock 140293701425752 acquired on /root/.cache/torch/transformers/08477dcecf305af90229876aa01e4b0f3594dc8c638985a72277f39ea7d8d0c3.7fb14267817b1d26bb44a57cd5aa2fc003c25e87b75ef77e9c55c4804675b4cf.lock
Downloading: 100% 499M/499M [00:06<00:00, 73.5MB/s]
01/14/2021 20:49:13 - INFO - filelock -   Lock 140293701425752 released on /root/.cache/torch/transformers/08477dcecf305af90229876aa01e4b0f3594dc8c638985a72277f39ea7d8d0c3.7fb14267817b1d26bb44a57cd5aa2fc003c25e87b75ef77e9c55c4804675b4cf.lock
01/14/2021 20:49:30 - INFO - __main__ -   *** Example ***
01/14/2021 20:49:30 - INFO - __main__ -   idx: 0
01/14/2021 20:49:30 - INFO - __main__ -   source_tokens: ['<s>', 'public', '_static', '_void', '_check', 'j', 'av', 'ain', 'ternal', 'access', '(', 'il', 'og', 'ger', '_logger', ')', '_{', '_if', '_(', 'log', 'ger', '_==', '_null', '_||', '_!', 'java', 'version', '.', 'is', 'at', 'le', 'ast', '(', 'java', 'version', '.', 'java', '_', '9', '))', '_{', '_//', '_older', '_java', '_versions', '_are', '_fine', '_with', '_the', '_reflection', '_return', ';', '_}', '_map', '<', 'string', ',', '_package', 'access', 'requ', 'irement', '[]', '>', '_requirements', '_=', '_new', '_tre', 'em', 'ap', '<', 'string', ',', '_package', 'access', 'requ', 'irement', '[]', '>', '();', '_requirements', '.', 'put', '("', 'java', '.', 'base', '",', '_new', '_package', 'access', 'requ', 'irement', '[]', '_{', '_create', 'requ', 'irement', '(', 'false', ',', '_"', 'j', 'dk', '.', 'internal', '.', 'ref', '"),', '_create', 'requ', 'irement', '(', 'true', ',', '_"', 'java', '.', 'lang', '"),', '_create', 'requ', 'irement', '(', 'true', ',', '_"', 'java', '.', 'n', 'io', '"),', '_create', 'requ', 'irement', '(', 'true', ',', '_"', 'sun', '.', 'n', 'io', '.', 'ch', '")', '_});', '_requirements', '.', 'put', '("', 'j', 'dk', '.', 'management', '",', '_get', 'j', 'dk', 'management', 'requ', 'irements', '());', '_requirements', '.', 'put', '("', 'java', '.', 'management', '",', '_new', '_package', 'access', 'requ', 'irement', '[]', '_{', '_create', 'requ', 'irement', '(', 'true', ',', '_"', 'sun', '.', 'management', '")', '_});', '_check', 'package', 'requ', 'irements', '(', 'log', 'ger', ',', '_requirements', ');', '_}', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   source_ids: 0 15110 25156 13842 1649 267 1469 1851 46378 28300 1640 718 2154 2403 37764 43 25522 114 36 12376 2403 45994 23796 45056 27785 43830 21747 4 354 415 459 1988 1640 43830 21747 4 43830 1215 466 35122 25522 21277 2530 46900 7952 32 2051 19 5 12456 671 131 35524 5456 41552 20951 6 3737 28300 42172 34074 48992 15698 3471 5457 92 6110 991 1115 41552 20951 6 3737 28300 42172 34074 48992 15698 47006 3471 4 9179 46469 43830 4 11070 1297 92 3737 28300 42172 34074 48992 25522 1045 42172 34074 1640 22303 6 22 267 43357 4 37559 4 13043 16844 1045 42172 34074 1640 29225 6 22 43830 4 32373 16844 1045 42172 34074 1640 29225 6 22 43830 4 282 1020 16844 1045 42172 34074 1640 29225 6 22 21381 4 282 1020 4 611 8070 47771 3471 4 9179 46469 267 43357 4 14668 1297 120 267 43357 14668 42172 48227 49291 3471 4 9179 46469 43830 4 14668 1297 92 3737 28300 42172 34074 48992 25522 1045 42172 34074 1640 29225 6 22 21381 4 14668 8070 47771 1649 46181 42172 48227 1640 12376 2403 6 3471 4397 35524 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   source_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:30 - INFO - __main__ -   target_tokens: ['<s>', 'prints', '_warning', '_to', '_given', '_if', '_haz', 'el', 'cast', '_is', '_not', '_provided', '_a', '_sufficient', '_access', '_to', '_java', '_internal', '_packages', '_on', '_java', '_9', '_and', '_newer', '.', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   target_ids: 0 31553 2892 7 576 114 32468 523 5182 16 45 1286 10 7719 899 7 46900 3425 8368 15 46900 361 8 13964 4 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   target_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:30 - INFO - __main__ -   *** Example ***
01/14/2021 20:49:30 - INFO - __main__ -   idx: 1
01/14/2021 20:49:30 - INFO - __main__ -   source_tokens: ['<s>', 'public', '_void', '_marsh', 'all', '(', 's', 'ct', 'e', '20', 'pl', 'use', 'mb', 'edd', 'edd', 'est', 'inations', 'ettings', '_s', 'ct', 'e', '20', 'pl', 'use', 'mb', 'edd', 'edd', 'est', 'inations', 'ettings', ',', '_protocol', 'm', 'arsh', 'all', 'er', '_protocol', 'm', 'arsh', 'all', 'er', ')', '_{', '_if', '_(', 's', 'ct', 'e', '20', 'pl', 'use', 'mb', 'edd', 'edd', 'est', 'inations', 'ettings', '_==', '_null', ')', '_{', '_throw', '_new', '_s', 'dk', 'client', 'ex', 'ception', '("', 'in', 'valid', '_argument', '_passed', '_to', '_marsh', 'all', '(', '...)', '");', '_}', '_try', '_{', '_}', '_catch', '_(', 'ex', 'ception', '_e', ')', '_{', '_throw', '_new', '_s', 'dk', 'client', 'ex', 'ception', '("', 'un', 'able', '_to', '_marsh', 'all', '_request', '_to', '_json', ':', '_"', '_+', '_e', '.', 'get', 'message', '(),', '_e', ');', '_}', '_}', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   source_ids: 0 15110 13842 16377 1250 1640 29 3894 242 844 2911 3698 6648 13093 13093 990 17808 48496 579 3894 242 844 2911 3698 6648 13093 13093 990 17808 48496 6 11883 119 14980 1250 254 11883 119 14980 1250 254 43 25522 114 36 29 3894 242 844 2911 3698 6648 13093 13093 990 17808 48496 45994 23796 43 25522 3211 92 579 43357 38557 3463 20900 46469 179 42679 4795 1595 7 16377 1250 1640 41137 45751 35524 860 25522 35524 2916 36 3463 20900 364 43 25522 3211 92 579 43357 38557 3463 20900 46469 879 868 7 16377 1250 2069 7 49133 35 22 2055 364 4 6460 44773 49196 364 4397 35524 35524 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   source_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:30 - INFO - __main__ -   target_tokens: ['<s>', 'm', 'arsh', 'all', '_the', '_given', '_parameter', '_object', '.', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   target_ids: 0 119 14980 1250 5 576 43797 7626 4 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   target_mask: 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:30 - INFO - __main__ -   *** Example ***
01/14/2021 20:49:30 - INFO - __main__ -   idx: 2
01/14/2021 20:49:30 - INFO - __main__ -   source_tokens: ['<s>', '@', 'over', 'ride', '_public', '_void', '_pref', 'etch', 'token', '(', 'final', '_file', '_token', 'file', ',', '_final', '_props', '_props', ',', '_final', '_logger', '_logger', ')', '_throws', '_had', 'oop', 'security', 'man', 'age', 'rex', 'ception', '_{', '_final', '_string', '_us', 'ert', 'op', 'roxy', '_=', '_props', '.', 'get', 'string', '(', 'job', 'properties', '.', 'user', '_', 'to', '_', 'proxy', ');', '_logger', '.', 'info', '("', 'getting', '_had', 'oop', '_tokens', '_based', '_on', '_props', '_for', '_"', '_+', '_us', 'ert', 'op', 'roxy', ');', '_dop', 'ref', 'etch', '(', 'token', 'file', ',', '_props', ',', '_logger', ',', '_us', 'ert', 'op', 'roxy', ');', '_}', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   source_ids: 0 1039 2137 23167 285 13842 33284 29094 46657 1640 6156 2870 19233 21710 6 507 26504 26504 6 507 37764 37764 43 6989 56 18042 15506 397 1580 19633 20900 25522 507 6755 201 2399 1517 46963 5457 26504 4 6460 20951 1640 30056 47276 4 12105 1215 560 1215 47315 4397 37764 4 23999 46469 31315 56 18042 22121 716 15 26504 13 22 2055 201 2399 1517 46963 4397 32331 13043 29094 1640 46657 21710 6 26504 6 37764 6 201 2399 1517 46963 4397 35524 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   source_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:30 - INFO - __main__ -   target_tokens: ['<s>', '/*', '_gets', '_had', 'oop', '_tokens', '_for', '_a', '_user', '_to', '_run', '_map', 'red', '/', 'h', 'ive', '_jobs', '_on', '_a', '_secured', '_cluster', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   target_ids: 0 49051 1516 56 18042 22121 13 10 3018 7 422 5456 2050 73 298 2088 1315 15 10 5288 18016 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   target_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:30 - INFO - __main__ -   *** Example ***
01/14/2021 20:49:30 - INFO - __main__ -   idx: 3
01/14/2021 20:49:30 - INFO - __main__ -   source_tokens: ['<s>', '@', 'over', 'ride', '_public', '_<', 'y', '>', '_singular', 'attribute', '<', 'x', ',', '_y', '>', '_get', 'decl', 'ared', 'id', '(', 'class', '<', 'y', '>', '_param', 'class', ')', '_{', '_if', '_(', 'id', 'attribute', '_!=', '_null', ')', '_{', '_if', '_(', 'id', 'attribute', '.', 'get', 'j', 'av', 'at', 'ype', '().', 'equ', 'als', '(', 'param', 'class', ')', '_&&', '_!', 'is', 'id', 'class', ')', '_{', '_return', '_(', 'sing', 'ular', 'attribute', '<', 'x', ',', '_y', '>)', '_id', 'attribute', ';', '_}', '_}', '_on', 'error', '();', '_return', '_null', ';', '_}', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   source_ids: 0 1039 2137 23167 285 28696 219 15698 23429 49202 41552 1178 6 1423 15698 120 32639 6537 808 1640 4684 41552 219 15698 40206 4684 43 25522 114 36 808 49202 49333 23796 43 25522 114 36 808 49202 4 6460 267 1469 415 37356 49123 8198 1536 1640 46669 4684 43 48200 27785 354 808 4684 43 25522 671 36 26058 8244 49202 41552 1178 6 1423 49798 13561 49202 131 35524 35524 15 44223 47006 671 23796 131 35524 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   source_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:30 - INFO - __main__ -   target_tokens: ['<s>', '/*', '_(', 'non', '-', 'j', 'av', 'ad', 'oc', ')', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   target_ids: 0 49051 36 13424 12 267 1469 625 1975 43 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   target_mask: 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:30 - INFO - __main__ -   *** Example ***
01/14/2021 20:49:30 - INFO - __main__ -   idx: 4
01/14/2021 20:49:30 - INFO - __main__ -   source_tokens: ['<s>', 'public', '_void', '_sync', '(', 'bo', 'olean', '_syn', 'call', 'se', 'gments', ')', '_{', '_commit', 'log', 'se', 'gment', '_current', '_=', '_alloc', 'ator', '.', 'all', 'ocating', 'from', '();', '_for', '_(', 'commit', 'log', 'se', 'gment', '_segment', '_:', '_alloc', 'ator', '.', 'get', 'act', 'ives', 'eg', 'ments', '())', '_{', '_if', '_(!', 'sync', 'all', 'se', 'gments', '_&&', '_segment', '.', 'id', '_>', '_current', '.', 'id', ')', '_return', ';', '_segment', '.', 'sync', '();', '_}', '_}', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   source_ids: 0 15110 13842 22785 1640 3983 48547 17796 16395 1090 30237 43 25522 6225 12376 1090 10757 595 5457 42793 2630 4 1250 18106 7761 47006 13 36 42721 12376 1090 10757 2835 4832 42793 2630 4 6460 7257 3699 3733 2963 49338 25522 114 48209 45176 1250 1090 30237 48200 2835 4 808 8061 595 4 808 43 671 131 2835 4 45176 47006 35524 35524 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   source_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:30 - INFO - __main__ -   target_tokens: ['<s>', 'forces', '_a', '_disk', '_flush', '_on', '_the', '_commit', '_log', '_files', '_that', '_need', '_it', '.', '_blocking', '.', '</s>']
01/14/2021 20:49:30 - INFO - __main__ -   target_ids: 0 34532 10 21675 24841 15 5 6225 7425 6773 14 240 24 4 8890 4 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
01/14/2021 20:49:30 - INFO - __main__ -   target_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
01/14/2021 20:49:33 - INFO - __main__ -   ***** Running training *****
01/14/2021 20:49:33 - INFO - __main__ -     Num examples = 2809
01/14/2021 20:49:33 - INFO - __main__ -     Batch size = 8
01/14/2021 20:49:33 - INFO - __main__ -     Num epoch = 10
epoch 0 loss 6.8534: 100% 352/352 [02:53<00:00,  2.03it/s]
01/14/2021 20:52:27 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 20:52:27 - INFO - __main__ -     Num examples = 88
01/14/2021 20:52:27 - INFO - __main__ -     Batch size = 8
01/14/2021 20:52:29 - INFO - __main__ -     eval_ppl = 420.66683
01/14/2021 20:52:29 - INFO - __main__ -     global_step = 353
01/14/2021 20:52:29 - INFO - __main__ -     train_loss = 6.8534
01/14/2021 20:52:29 - INFO - __main__ -     ********************
01/14/2021 20:52:31 - INFO - __main__ -     Best ppl:420.66683
01/14/2021 20:52:31 - INFO - __main__ -     ********************
Total: 88
01/14/2021 20:52:58 - INFO - __main__ -     bleu-4 = 9.79 
01/14/2021 20:52:58 - INFO - __main__ -     ********************
01/14/2021 20:52:58 - INFO - __main__ -     Best bleu:9.79
01/14/2021 20:52:58 - INFO - __main__ -     ********************
epoch 1 loss 5.2249: 100% 352/352 [02:57<00:00,  1.98it/s]
01/14/2021 20:55:58 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 20:55:58 - INFO - __main__ -     Num examples = 88
01/14/2021 20:55:58 - INFO - __main__ -     Batch size = 8
01/14/2021 20:56:00 - INFO - __main__ -     eval_ppl = 223.30135
01/14/2021 20:56:00 - INFO - __main__ -     global_step = 705
01/14/2021 20:56:00 - INFO - __main__ -     train_loss = 5.2249
01/14/2021 20:56:00 - INFO - __main__ -     ********************
01/14/2021 20:56:02 - INFO - __main__ -     Best ppl:223.30135
01/14/2021 20:56:02 - INFO - __main__ -     ********************
Total: 88
01/14/2021 20:56:30 - INFO - __main__ -     bleu-4 = 10.3 
01/14/2021 20:56:30 - INFO - __main__ -     ********************
01/14/2021 20:56:30 - INFO - __main__ -     Best bleu:10.3
01/14/2021 20:56:30 - INFO - __main__ -     ********************
epoch 2 loss 4.4676: 100% 352/352 [02:57<00:00,  1.98it/s]
01/14/2021 20:59:31 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 20:59:31 - INFO - __main__ -     Num examples = 88
01/14/2021 20:59:31 - INFO - __main__ -     Batch size = 8
01/14/2021 20:59:32 - INFO - __main__ -     eval_ppl = 167.43889
01/14/2021 20:59:32 - INFO - __main__ -     global_step = 1057
01/14/2021 20:59:32 - INFO - __main__ -     train_loss = 4.4676
01/14/2021 20:59:32 - INFO - __main__ -     ********************
01/14/2021 20:59:35 - INFO - __main__ -     Best ppl:167.43889
01/14/2021 20:59:35 - INFO - __main__ -     ********************
Total: 88
01/14/2021 21:00:05 - INFO - __main__ -     bleu-4 = 10.68 
01/14/2021 21:00:05 - INFO - __main__ -     ********************
01/14/2021 21:00:05 - INFO - __main__ -     Best bleu:10.68
01/14/2021 21:00:05 - INFO - __main__ -     ********************
epoch 3 loss 3.8263: 100% 352/352 [02:57<00:00,  1.98it/s]
01/14/2021 21:03:05 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 21:03:05 - INFO - __main__ -     Num examples = 88
01/14/2021 21:03:05 - INFO - __main__ -     Batch size = 8
01/14/2021 21:03:07 - INFO - __main__ -     eval_ppl = 160.25635
01/14/2021 21:03:07 - INFO - __main__ -     global_step = 1409
01/14/2021 21:03:07 - INFO - __main__ -     train_loss = 3.8263
01/14/2021 21:03:07 - INFO - __main__ -     ********************
01/14/2021 21:03:10 - INFO - __main__ -     Best ppl:160.25635
01/14/2021 21:03:10 - INFO - __main__ -     ********************
Total: 88
01/14/2021 21:03:38 - INFO - __main__ -     bleu-4 = 11.04 
01/14/2021 21:03:38 - INFO - __main__ -     ********************
01/14/2021 21:03:38 - INFO - __main__ -     Best bleu:11.04
01/14/2021 21:03:38 - INFO - __main__ -     ********************
epoch 4 loss 3.2797: 100% 352/352 [02:57<00:00,  1.98it/s]
01/14/2021 21:06:38 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 21:06:38 - INFO - __main__ -     Num examples = 88
01/14/2021 21:06:38 - INFO - __main__ -     Batch size = 8
01/14/2021 21:06:40 - INFO - __main__ -     eval_ppl = 152.19858
01/14/2021 21:06:40 - INFO - __main__ -     global_step = 1761
01/14/2021 21:06:40 - INFO - __main__ -     train_loss = 3.2797
01/14/2021 21:06:40 - INFO - __main__ -     ********************
01/14/2021 21:06:42 - INFO - __main__ -     Best ppl:152.19858
01/14/2021 21:06:42 - INFO - __main__ -     ********************
Total: 88
01/14/2021 21:07:14 - INFO - __main__ -     bleu-4 = 10.36 
01/14/2021 21:07:14 - INFO - __main__ -     ********************
epoch 5 loss 2.8204: 100% 352/352 [02:57<00:00,  1.98it/s]
01/14/2021 21:10:12 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 21:10:12 - INFO - __main__ -     Num examples = 88
01/14/2021 21:10:12 - INFO - __main__ -     Batch size = 8
01/14/2021 21:10:13 - INFO - __main__ -     eval_ppl = 150.95443
01/14/2021 21:10:13 - INFO - __main__ -     global_step = 2113
01/14/2021 21:10:13 - INFO - __main__ -     train_loss = 2.8204
01/14/2021 21:10:13 - INFO - __main__ -     ********************
01/14/2021 21:10:16 - INFO - __main__ -     Best ppl:150.95443
01/14/2021 21:10:16 - INFO - __main__ -     ********************
Total: 88
01/14/2021 21:10:45 - INFO - __main__ -     bleu-4 = 11.57 
01/14/2021 21:10:45 - INFO - __main__ -     ********************
01/14/2021 21:10:45 - INFO - __main__ -     Best bleu:11.57
01/14/2021 21:10:45 - INFO - __main__ -     ********************
epoch 6 loss 2.4442: 100% 352/352 [02:57<00:00,  1.98it/s]
01/14/2021 21:13:46 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 21:13:46 - INFO - __main__ -     Num examples = 88
01/14/2021 21:13:46 - INFO - __main__ -     Batch size = 8
01/14/2021 21:13:47 - INFO - __main__ -     eval_ppl = 156.69898
01/14/2021 21:13:47 - INFO - __main__ -     global_step = 2465
01/14/2021 21:13:47 - INFO - __main__ -     train_loss = 2.4442
01/14/2021 21:13:47 - INFO - __main__ -     ********************
Total: 88
01/14/2021 21:14:17 - INFO - __main__ -     bleu-4 = 10.65 
01/14/2021 21:14:17 - INFO - __main__ -     ********************
epoch 7 loss 2.1565: 100% 352/352 [02:57<00:00,  1.98it/s]
01/14/2021 21:17:15 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 21:17:15 - INFO - __main__ -     Num examples = 88
01/14/2021 21:17:15 - INFO - __main__ -     Batch size = 8
01/14/2021 21:17:16 - INFO - __main__ -     eval_ppl = 163.34726
01/14/2021 21:17:16 - INFO - __main__ -     global_step = 2817
01/14/2021 21:17:16 - INFO - __main__ -     train_loss = 2.1565
01/14/2021 21:17:16 - INFO - __main__ -     ********************
Total: 88
01/14/2021 21:17:50 - INFO - __main__ -     bleu-4 = 10.56 
01/14/2021 21:17:50 - INFO - __main__ -     ********************
epoch 8 loss 1.9398: 100% 352/352 [02:57<00:00,  1.98it/s]
01/14/2021 21:20:47 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 21:20:47 - INFO - __main__ -     Num examples = 88
01/14/2021 21:20:47 - INFO - __main__ -     Batch size = 8
01/14/2021 21:20:49 - INFO - __main__ -     eval_ppl = 166.41823
01/14/2021 21:20:49 - INFO - __main__ -     global_step = 3169
01/14/2021 21:20:49 - INFO - __main__ -     train_loss = 1.9398
01/14/2021 21:20:49 - INFO - __main__ -     ********************
Total: 88
01/14/2021 21:21:26 - INFO - __main__ -     bleu-4 = 10.74 
01/14/2021 21:21:26 - INFO - __main__ -     ********************
epoch 9 loss 1.7877: 100% 352/352 [02:57<00:00,  1.98it/s]
01/14/2021 21:24:24 - INFO - __main__ -   
***** Running evaluation *****
01/14/2021 21:24:24 - INFO - __main__ -     Num examples = 88
01/14/2021 21:24:24 - INFO - __main__ -     Batch size = 8
01/14/2021 21:24:25 - INFO - __main__ -     eval_ppl = 169.37057
01/14/2021 21:24:25 - INFO - __main__ -     global_step = 3521
01/14/2021 21:24:25 - INFO - __main__ -     train_loss = 1.7877
01/14/2021 21:24:25 - INFO - __main__ -     ********************
Total: 88
01/14/2021 21:24:59 - INFO - __main__ -     bleu-4 = 10.28 
01/14/2021 21:24:59 - INFO - __main__ -     ********************

Yay! Our model has finished baking and we can now see how well it turned out by evaluating it!

batch_size=64
dev_file=f"{data_dir}/{lang}/valid.jsonl"
test_file=f"{data_dir}/{lang}/test.jsonl"
test_model=f"{output_dir}/checkpoint-best-bleu/pytorch_model.bin" #checkpoint for test

! python run.py \
    --do_test \
    --model_type roberta \
    --model_name_or_path microsoft/codebert-base \
    --load_model_path {test_model} \
    --dev_filename {dev_file} \
    --test_filename {test_file} \
    --output_dir {output_dir} \
    --max_source_length {source_length} \
    --max_target_length {target_length} \
    --beam_size {beam_size} \
    --eval_batch_size {batch_size}
2021-01-14 21:25:04.498200: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
01/14/2021 21:25:07 - INFO - __main__ -   Namespace(adam_epsilon=1e-08, beam_size=10, config_name='', dev_filename='./java/valid.jsonl', do_eval=False, do_lower_case=False, do_test=True, do_train=False, eval_batch_size=64, eval_steps=-1, gradient_accumulation_steps=1, learning_rate=5e-05, load_model_path='model/java/checkpoint-best-bleu/pytorch_model.bin', local_rank=-1, max_grad_norm=1.0, max_source_length=256, max_steps=-1, max_target_length=48, model_name_or_path='microsoft/codebert-base', model_type='roberta', no_cuda=False, num_train_epochs=3, output_dir='model/java', seed=42, test_filename='./java/test.jsonl', tokenizer_name='', train_batch_size=8, train_filename=None, train_steps=-1, warmup_steps=0, weight_decay=0.0)
01/14/2021 21:25:07 - WARNING - __main__ -   Process rank: -1, device: cuda, n_gpu: 1, distributed training: False
01/14/2021 21:25:23 - INFO - __main__ -   reload model from model/java/checkpoint-best-bleu/pytorch_model.bin
01/14/2021 21:25:48 - INFO - __main__ -   Test file: ./java/valid.jsonl
100% 2/2 [00:26<00:00, 13.34s/it]
Total: 88
01/14/2021 21:26:15 - INFO - __main__ -     bleu-4 = 11.57 
01/14/2021 21:26:15 - INFO - __main__ -     ********************
01/14/2021 21:26:15 - INFO - __main__ -   Test file: ./java/test.jsonl
100% 4/4 [00:55<00:00, 13.95s/it]
Total: 193
01/14/2021 21:27:11 - INFO - __main__ -     bleu-4 = 9.74 
01/14/2021 21:27:11 - INFO - __main__ -     ********************

Let's now load up our model and take it for a spin!

import torch

import torch.nn as nn

from model import Seq2Seq
from transformers import RobertaConfig, RobertaModel

config = RobertaConfig.from_pretrained(pretrained_model)
encoder = RobertaModel.from_pretrained(pretrained_model, config = config)    
decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
model = Seq2Seq(encoder = encoder,decoder = decoder,config=config,
                beam_size=beam_size,max_length=target_length,
                sos_id=tokenizer.cls_token_id,eos_id=tokenizer.sep_token_id)
model.load_state_dict(torch.load(Path(output_dir)/"checkpoint-last/pytorch_model.bin"))
model.to('cuda')
Seq2Seq(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): RobertaIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): RobertaOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        ... (layers (1) through (11) are identical in structure to layer (0); omitted for brevity) ...
      )
    )
    (pooler): RobertaPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
        (dropout3): Dropout(p=0.1, inplace=False)
      )
      ... (layers (1) through (5) are identical in structure to layer (0); omitted for brevity) ...
    )
  )
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (lm_head): Linear(in_features=768, out_features=50265, bias=False)
  (lsm): LogSoftmax()
)
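Before we generate anything, it's worth sanity-checking how big this model actually is. A standard PyTorch one-liner for counting trainable parameters:

# Rough size check: total trainable parameters, reported in millions.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{num_params / 1e6:.1f}M trainable parameters')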
idx = 0
TEXT_TO_SUMMARIZE = df_val.mthd.values[idx]
print('Code:', TEXT_TO_SUMMARIZE)
print('Original Comment:', df_val.cmt.values[idx])
Code: public static byte[] decode(final string s) { int delta = s.endswith("==") ? 2 : s.endswith("=") ? 1 : 0; byte[] buffer = new byte[s.length() * bytes_per_unencoded_block / bytes_per_encoded_block - delta]; int mask = 0xff; int pos = 0; for (int i = 0; i < s.length(); i += bytes_per_encoded_block) { int c0 = decode_table[s.charat(i)]; int c1 = decode_table[s.charat(i + 1)]; buffer[pos++] = (byte) (((c0 << 2) | (c1 >> 4)) & mask); if (pos >= buffer.length) { return buffer; } int c2 = decode_table[s.charat(i + 2)]; buffer[pos++] = (byte) (((c1 << 4) | (c2 >> 2)) & mask); if (pos >= buffer.length) { return buffer; } int c3 = decode_table[s.charat(i + 3)]; buffer[pos++] = (byte) (((c2 << 6) | c3) & mask); } return buffer; }
Original Comment: decodes the given base64-encoded string.

from run import convert_examples_to_features, Example

# The feature-conversion helper expects an args object with these two fields.
class Args:
    max_source_length = source_length
    max_target_length = target_length

args = Args()

def get_preds(df: pd.DataFrame):
    ps = []
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        # Wrap each method/comment pair in the Example format run.py expects.
        examples = [Example(idx, source=row.mthd, target=row.cmt)]
        eval_features = convert_examples_to_features(
            examples, tokenizer, args, stage='test'
        )
        source_ids = torch.tensor(eval_features[0].source_ids, dtype=torch.long).unsqueeze(0).to('cuda')
        source_mask = torch.tensor(eval_features[0].source_mask, dtype=torch.long).unsqueeze(0).to('cuda')

        with torch.no_grad():
            preds = model(source_ids=source_ids, source_mask=source_mask)
            for pred in preds:
                # Beam search returns token ids; truncate at the first padding
                # id (0) and decode the rest back into text.
                t = list(pred[0].cpu().numpy())
                if 0 in t:
                    t = t[:t.index(0)]
                text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
                ps.append(text)

    return ps
df_val = df_val.reset_index()
preds = get_preds(df_val.head(10))
for idx, row in df_val.head(10).iterrows():
    print('Code:', row.mthd)
    print('Original Comment:', row.cmt)
    print('Generated Comment:', preds[idx])
    print('='*40)
Code: public static byte[] decode(final string s) { int delta = s.endswith("==") ? 2 : s.endswith("=") ? 1 : 0; byte[] buffer = new byte[s.length() * bytes_per_unencoded_block / bytes_per_encoded_block - delta]; int mask = 0xff; int pos = 0; for (int i = 0; i < s.length(); i += bytes_per_encoded_block) { int c0 = decode_table[s.charat(i)]; int c1 = decode_table[s.charat(i + 1)]; buffer[pos++] = (byte) (((c0 << 2) | (c1 >> 4)) & mask); if (pos >= buffer.length) { return buffer; } int c2 = decode_table[s.charat(i + 2)]; buffer[pos++] = (byte) (((c1 << 4) | (c2 >> 2)) & mask); if (pos >= buffer.length) { return buffer; } int c3 = decode_table[s.charat(i + 3)]; buffer[pos++] = (byte) (((c2 << 6) | c3) & mask); } return buffer; }
Original Comment: decodes the given base64-encoded string.
Generated Comment: decode encode a string representation of string
========================================
Code: private void extractapklib( artifact apklibartifact ) throws mojoexecutionexception { getunpackedlibhelper().extractapklib( apklibartifact ); // copy the assets to the the combinedassets folder. // add the apklib source and resource to the compile. // nb apklib sources are added to compilesourceroot because we may need to compile against them. // this means the apklib classes will be compiled into target/classes and packaged with this build. copyfolder( getunpackedlibassetsfolder( apklibartifact ), combinedassets ); final file apklibsourcefolder = getunpackedapklibsourcefolder( apklibartifact ); final list<string> resourceexclusions = arrays.aslist( "**/*.java", "**/*.aidl" ); projecthelper.addresource( project, apklibsourcefolder.getabsolutepath(), null, resourceexclusions ); project.addcompilesourceroot( apklibsourcefolder.getabsolutepath() ); }
Original Comment: extracts apklib and adds the assets and apklib sources and resources to the build.
Generated Comment: extracts the cp libraries from the given source library.
========================================
Code: static <t> t[] copy(object[] source, int from, int to, t[] arrayoftype) { t[] result = newarray(arrayoftype, to - from); system.arraycopy(source, from, result, 0, to - from); return result; }
Original Comment: equivalent to arrays.copyofrange(source, from, to, arrayoftype.getclass()).
Generated Comment: creates a new object from the array.
========================================
Code: private static runtimedelegate finddelegate() { runtimedelegate result=null; try { result=createruntimedelegatefromspi(); if(result==null) { result=createruntimedelegatefromconfigurationfile(); } if(result==null) { string delegateclassname = system.getproperty(application_engine_spi_property); if(delegateclassname!=null) { result=createruntimedelegateforclassname(delegateclassname); } } } catch (exception ex) { logger.warn("could not find application engine",ex); } return result; }
Original Comment: obtain an instance using the method described in }.
Generated Comment: /* package
========================================
Code: public static string getcategory(string eventsrcname) { if (eventsrcname == null) { return null; } int end = eventsrcname.lastindexof('.'); eventsrcname = eventsrcname.substring(0, end); if (checkstyle_package.equals(eventsrcname)) { return "misc"; } else if (!eventsrcname.startswith(checkstyle_package)) { return "extension"; } return eventsrcname.substring(eventsrcname.lastindexof('.') + 1); }
Original Comment: get the rule category from an audit event source name.
Generated Comment: returns the contents of the event name.
========================================
Code: private collection<artifact> getserverdependencies(final string servertype, final expressionevaluator expressionevaluator) throws componentconfigurationexception { try { final mavenproject project = (mavenproject) expressionevaluator.evaluate("${project}"); final string localrepo = (string) expressionevaluator.evaluate("${settings.localrepository}"); final artifactrepository localrepository = repositorysystem.createlocalrepository(new file(localrepo)); final repositoryrequest repositoryrequest = new defaultrepositoryrequest(); repositoryrequest.setremoterepositories(project.getremoteartifactrepositories()); repositoryrequest.setlocalrepository(localrepository); final artifactresolutionrequest request = new artifactresolutionrequest(repositoryrequest); request.setartifact(getserverartifact(servertype)); request.setresolvetransitively(true); final artifactresolutionresult result = repositorysystem.resolve(request); if (result.issuccess()) { return result.getartifacts(); } boolean first = true; final stringbuilder builder = new stringbuilder("cannot resolve dependencies: ["); for (final artifact artifact : result.getmissingartifacts()) { if (!first) { builder.append(','); } else { first = false; } builder.append(artifact.getgroupid()); builder.append(':'); builder.append(artifact.getartifactid()); builder.append(':'); builder.append(artifact.getversion()); } builder.append("]"); throw new componentconfigurationexception(builder.tostring()); } catch (final expressionevaluationexception e) { throw new componentconfigurationexception("error evaluating expression", e); } catch (final invalidrepositoryexception e) { throw new componentconfigurationexception("error resolving local repository", e); } }
Original Comment: resolve the ldap server type artifact and its dependencies.
Generated Comment: gets the repositories from the repository.
========================================
Code: private void frame4() { long currenttime = system.currenttimemillis(); // xxx: lots of dummy value // record trade information in trade table. // insert into trade (t_id, t_dts, t_st_id, t_tt_id, t_is_cash, // t_s_symb, t_qty, t_bid_price, t_ca_id, t_exec_name, t_trade_price, // t_chrg, t_comm, t_tax, t_lifo) values (...) string sql = string.format("insert into trade (t_id, t_dts, t_st_id, t_tt_id, " + "t_is_cash, t_s_symb, t_qty, t_bid_price, t_ca_id, t_exec_name, " + "t_trade_price, t_chrg, t_comm, t_tax, t_lifo) values (%d, %d, '%s', " + "'%s', %d, '%s', %d, %f, %d, '%s', %f, %f, %f, %f, %d)", paramhelper.gettradeid(), currenttime, statusid, paramhelper.gettradetypeid(), 1, paramhelper.getsymbol(), paramhelper.gettradeqty(), marketprice, paramhelper.getacctid(), "exec_name", paramhelper.gettradeprice(), 0.0, 0.0, 0.0, 1); executeupdate(sql); // todo: implement this (not in the simplified version) // record pending trade information in trade_request table // if this trade is a limit trade // insert into trade_request (tr_t_id, tr_tt_id, tr_s_symb, tr_qty, // tr_bid_price, tr_b_id) values (...) // record trade information in trade_history table // insert into trade_history (th_t_id, th_dts, th_st_id) values (...) sql = string.format("insert into trade_history (th_t_id, th_dts, th_st_id) values " + "(%d, %d, '%s')", paramhelper.gettradeid(), currenttime, statusid); executeupdate(sql); }
Original Comment: record the trade request by making all related updates
Generated Comment: this method is used to create the database.
========================================
Code: protected string getquery() { final stringbuilder ret = new stringbuilder(); try { final string clazzname; if (efapssystemconfiguration.get().containsattributevalue("org.efaps.kernel.index.querybuilder")) { clazzname = efapssystemconfiguration.get().getattributevalue("org.efaps.kernel.index.querybuilder"); } else { clazzname = "org.efaps.esjp.admin.index.lucencequerybuilder"; } final class<?> clazz = class.forname(clazzname, false, efapsclassloader.getinstance()); final object obj = clazz.newinstance(); final method method = clazz.getmethod("getquery4dimvalues", string.class, list.class, list.class); final object newquery = method.invoke(obj, getcurrentquery(), getincluded(), getexcluded()); ret.append(newquery); } catch (final efapsexception | classnotfoundexception | instantiationexception | illegalaccessexception | nosuchmethodexception | securityexception | illegalargumentexception | invocationtargetexception e) { indexsearch.log.error("catched", e); ret.append(getcurrentquery()); } return ret.tostring(); }
Original Comment: gets the query.
Generated Comment: get the query instance.
========================================
Code: private languagedata findlanguage(final string locale) { for (final languagedata languagedata : languagedatadao.getall()) { if (languagedata.getlanguagecode().equalsignorecase(locale)) { return languagedata; } } return null; }
Original Comment: find language.
Generated Comment: gets the specified locale.
========================================
Code: private standardintrospectionresponse callstandardintrospection(string parameters) { if (parameters == null) { // authlete returns different error codes for null and an empty string. // 'null' is regarded as a caller's error. an empty string is regarded // as a client application's error. parameters = ""; } // create a request for authlete's /api/auth/introspection/standard api. standardintrospectionrequest request = new standardintrospectionrequest() .setparameters(parameters); try { // call authlete's /api/auth/introspection/standard api. return mapi.standardintrospection(request); } catch (authleteapiexception e) { // the api call failed. throw apifailure("/api/auth/introspection/standard", e); } }
Original Comment: call authlete's api.
Generated Comment: returns a set of authentication object.
========================================

The model seems to be doing a decent job, but if you play with it some more you'll notice it mostly latches onto the method's name and uses that to guide the comment. This makes sense, but it suggests the model isn't learning much beyond that association, at least at this model size. One quick way to test that hypothesis is to strip out the method names and see how the generated comments change, as in the sketch below. After that, let's dig in further by looking at the validation examples the model fails on the most, i.e., those with the highest loss.
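
Here's a minimal sketch of that probe; anonymize_method_name is a hypothetical helper that uses a crude regex (not a real Java parser) to rename the method before reusing get_preds from above:

import re

def anonymize_method_name(code: str) -> str:
    # Crude heuristic: rename the first identifier that precedes a '(' to 'f'.
    # Good enough for a quick sanity check on a few examples.
    return re.sub(r'\w+\s*\(', 'f(', code, count=1)

probe = df_val.head(5).copy()
probe['mthd'] = probe['mthd'].apply(anonymize_method_name)

for (_, row), pred in zip(probe.iterrows(), get_preds(probe)):
    print('Anonymized code:', row.mthd[:100], '...')
    print('Generated comment:', pred)
    print('=' * 40)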

def get_preds_losses(df: pd.DataFrame):
    ps = []
    losses = []
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        examples = [Example(idx, source=row.mthd, target=row.cmt)]
        eval_features = convert_examples_to_features(
            examples, tokenizer, args, stage='test'
        )
        source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long).to('cuda')
        source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long).to('cuda')
        target_ids = torch.tensor([f.target_ids for f in eval_features], dtype=torch.long).to('cuda')
        target_mask = torch.tensor([f.target_mask for f in eval_features], dtype=torch.long).to('cuda')

        with torch.no_grad():
            # Passing the targets makes the model return the loss for this example...
            _, loss, _ = model(
                source_ids=source_ids, source_mask=source_mask,
                target_ids=target_ids, target_mask=target_mask
            )
            # ...while omitting them runs beam search to generate a prediction.
            preds = model(source_ids=source_ids, source_mask=source_mask)
            for pred in preds:
                t = list(pred[0].cpu().numpy())
                if 0 in t:
                    t = t[:t.index(0)]
                text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
                ps.append(text)
                losses.append(loss.item())

    return ps, losses
df_head = df_val.copy()
ps, losses = get_preds_losses(df_head)
df_head['pred'] = ps
df_head['loss'] = losses
df_sorted_losses = df_head.sort_values('loss', ascending=False)

for _, row in df_sorted_losses.head(10).iterrows():
    print('Code:', row.mthd)
    print('Original Comment:', row.cmt)
    print('Generated Comment:', row.pred)
    print(row.loss)
    print('='*40)
Code: private collection<artifact> getserverdependencies(final string servertype, final expressionevaluator expressionevaluator) throws componentconfigurationexception { try { final mavenproject project = (mavenproject) expressionevaluator.evaluate("${project}"); final string localrepo = (string) expressionevaluator.evaluate("${settings.localrepository}"); final artifactrepository localrepository = repositorysystem.createlocalrepository(new file(localrepo)); final repositoryrequest repositoryrequest = new defaultrepositoryrequest(); repositoryrequest.setremoterepositories(project.getremoteartifactrepositories()); repositoryrequest.setlocalrepository(localrepository); final artifactresolutionrequest request = new artifactresolutionrequest(repositoryrequest); request.setartifact(getserverartifact(servertype)); request.setresolvetransitively(true); final artifactresolutionresult result = repositorysystem.resolve(request); if (result.issuccess()) { return result.getartifacts(); } boolean first = true; final stringbuilder builder = new stringbuilder("cannot resolve dependencies: ["); for (final artifact artifact : result.getmissingartifacts()) { if (!first) { builder.append(','); } else { first = false; } builder.append(artifact.getgroupid()); builder.append(':'); builder.append(artifact.getartifactid()); builder.append(':'); builder.append(artifact.getversion()); } builder.append("]"); throw new componentconfigurationexception(builder.tostring()); } catch (final expressionevaluationexception e) { throw new componentconfigurationexception("error evaluating expression", e); } catch (final invalidrepositoryexception e) { throw new componentconfigurationexception("error resolving local repository", e); } }
Original Comment: resolve the ldap server type artifact and its dependencies.
Generated Comment: gets the repository from the repository.
24.875783920288086
========================================
Code: public static byte[] decode(final string s) { int delta = s.endswith("==") ? 2 : s.endswith("=") ? 1 : 0; byte[] buffer = new byte[s.length() * bytes_per_unencoded_block / bytes_per_encoded_block - delta]; int mask = 0xff; int pos = 0; for (int i = 0; i < s.length(); i += bytes_per_encoded_block) { int c0 = decode_table[s.charat(i)]; int c1 = decode_table[s.charat(i + 1)]; buffer[pos++] = (byte) (((c0 << 2) | (c1 >> 4)) & mask); if (pos >= buffer.length) { return buffer; } int c2 = decode_table[s.charat(i + 2)]; buffer[pos++] = (byte) (((c1 << 4) | (c2 >> 2)) & mask); if (pos >= buffer.length) { return buffer; } int c3 = decode_table[s.charat(i + 3)]; buffer[pos++] = (byte) (((c2 << 6) | c3) & mask); } return buffer; }
Original Comment: decodes the given base64-encoded string.
Generated Comment: encodes a string value from a string.
24.304515838623047
========================================
Code: @override public void init(configurationvalueprovider... configurationvalueproviders) { if (configurationvalueproviders != null) { for (configurationproperty property : getcontainer().properties.values()) { property.init(configurationvalueproviders); } } }
Original Comment: override default values for properties with the given configurationproviders.
Generated Comment: configures all the options in the given configuration.
24.276317596435547
========================================
Code: private static boolean validatepart(string part, boolean isfinalpart) { // these tests could be collapsed into one big boolean expression, but // they have been left as independent tests for clarity. if (part.length() < 1 || part.length() > max_domain_part_length) { return false; } /* * gwt claims to support java.lang.character's char-classification methods, but it actually only * works for ascii. so for now, assume any non-ascii characters are valid. the only place this * seems to be documented is here: * http://osdir.com/ml/googlewebtoolkitcontributors/2010-03/msg00178.html * * <p>ascii characters in the part are expected to be valid per rfc 1035, with underscore also * being allowed due to widespread practice. */ string asciichars = charmatcher.ascii().retainfrom(part); if (!part_char_matcher.matchesallof(asciichars)) { return false; } // no initial or final dashes or underscores. if (dash_matcher.matches(part.charat(0)) || dash_matcher.matches(part.charat(part.length() - 1))) { return false; } /* * note that we allow (in contravention of a strict interpretation of the relevant rfcs) domain * parts other than the last may begin with a digit (for example, "3com.com"). it's important to * disallow an initial digit in the last part; it's the only thing that stops an ipv4 numeric * address like 127.0.0.1 from looking like a valid domain name. */ if (isfinalpart && charmatcher.digit().matches(part.charat(0))) { return false; } return true; }
Original Comment: helper method for }. validates that one part of a domain name is valid.
Generated Comment: parses a string representation of the given string.
24.256574630737305
========================================
Code: private void extractapklib( artifact apklibartifact ) throws mojoexecutionexception { getunpackedlibhelper().extractapklib( apklibartifact ); // copy the assets to the the combinedassets folder. // add the apklib source and resource to the compile. // nb apklib sources are added to compilesourceroot because we may need to compile against them. // this means the apklib classes will be compiled into target/classes and packaged with this build. copyfolder( getunpackedlibassetsfolder( apklibartifact ), combinedassets ); final file apklibsourcefolder = getunpackedapklibsourcefolder( apklibartifact ); final list<string> resourceexclusions = arrays.aslist( "**/*.java", "**/*.aidl" ); projecthelper.addresource( project, apklibsourcefolder.getabsolutepath(), null, resourceexclusions ); project.addcompilesourceroot( apklibsourcefolder.getabsolutepath() ); }
Original Comment: extracts apklib and adds the assets and apklib sources and resources to the build.
Generated Comment: extracts compiled from the cp compiler.
23.989707946777344
========================================
Code: public void adddefaultheader(final string name, final string value) { validate.notempty(name, "header name cannot be empty"); validate.notnull(value, "header value cannot be null, use an empty string instead"); this.checkconfigurable(); this.defaultheaders.put(name, value); }
Original Comment: adds a default header to be added to every stub http response.
Generated Comment: adds the headers to the headers.
23.846609115600586
========================================
Code: public static schema getschema(final file xsd, final errorhandler errorhandler) throws saxexception { // create a new instance for an xsd-aware schemafactory final schemafactory schemafactory = schemafactory .newinstance(http_www_w3_org_2001_xml_schema); // set the errorhandler implementation. schemafactory.seterrorhandler(errorhandler); // get the custom xsd schema that describes // the required format for my xml files. return schemafactory.newschema(xsd); }
Original Comment: gets the schema.
Generated Comment: creates a xml object from the given namespace.
23.77509880065918
========================================
Code: @override protected formatwriter createwriter(final outputstream outputstream, final formatlogger logger) { try { return new dsmlformatwriter(outputstream); } catch (final ioexception e) { logger.logerror("could not create and intialise the dsml writer", e); } return null; }
Original Comment: create the ldap writer that will dump ldap entries to a dsml file.
Generated Comment: creates a new writer as a xml file.
23.688125610351562
========================================
Code: @override public volatileimage createcompatiblevolatileimage(int width, int height, imagecapabilities caps, int transparency) throws awtexception { if (img == null) { img = new bufferedimage(1, 1, bufferedimage.type_int_argb); gc = img.creategraphics().getdeviceconfiguration(); } return gc.createcompatiblevolatileimage(width, height, caps, transparency); }
Original Comment: returns a volatile image. this method is a workaround for a classcastexception that occurs on macosx when exporting a swing ui that uses the nimbus look and feel to svg.
Generated Comment: create a new image from a new image.
23.60519790649414
========================================
Code: private static void printstacktrace(printstream out, throwable err) { out.println(err.getclass().getname() + ": " + err.getmessage()); for (stacktraceelement ste : err.getstacktrace()) { out.println("\tat " + ste.tostring()); } if (err.getcause() != null) { out.print("caused by: "); printstacktrace(out, err.getcause()); } }
Original Comment: print a complete stack trace. this differs from throwable.printstacktrace() in that it always prints all of the trace.
Generated Comment: print out a message.
23.529924392700195
========================================

What's Next?

If you'd like to see how you can integrate this comment-generation model into the popular VSCode editor as an extension, check out my video that walks through exactly that!
