Hi there! In this post you'll learn how to finetune a RoBERTa-based model that's been pretrained on code to automatically generate comments for code!
We will be focusing on the Java programming language, but you can apply the same techniques from this post to any programming language that interests you. Additionally, you'll see how to incorporate this code commenter into a VSCode extension so that you can generate comments for code snippets you highlight:
(Insert GIF of tool working)
As always, we'll start with a bit of background on the data and model we are using, but feel free to skip ahead if you want to get straight to the awesomeness ;). Alright, let's GO!
Background
Data
We will be using the awesome CodeSearchNet Challenge dataset, which contains millions of pairs of methods and their docstrings across a variety of programming languages. The dataset was originally constructed for evaluating how well different approaches perform at searching for code, but we can easily repurpose it, and lucky for us, the authors did a great job collecting, documenting, and cleaning the data.
We'll be performing a bit more cleaning and formatting of the data as well as adding some more examples. These examples won't be method/docstring pairs, but code snippet/inline comment pairs. This allows our model to generate comments for arbitrary code snippets that a developer may want to document instead of just generating the docstring of a method.
CodeBERT
The pretrained model we will be finetuning comes from the paper out of Microsoft Research aptly named CodeBERT: A Pre-Trained Model for Programming and Natural Languages. This model also used the CodeSearchNet Challenge dataset, but instead of using it to generate comments, it used it to teach a RoBERTa-based model to represent code and natural language in a useful way. Teaching large language models to represent text in a useful way is common practice now, since these representations have been shown to be helpful when finetuning the models on other tasks. The CodeBERT paper demonstrated that its representations are helpful by finetuning the model on the tasks of code search and comment generation, exactly what we will be doing! The difference between their comment generation task and ours is that we will do a bit more preprocessing, and our model will be able to generate inline comments for code snippets, not just method-level comments.
So, how does CodeBERT learn these representations? It combines two training objectives that have been shown to be useful for natural language: the Masked Language Modeling (MLM) objective, which comes from the original BERT paper, and the Replaced Token Detection (RTD) objective, which comes from the ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators paper. In the MLM objective, we randomly mask out parts of the text we feed into the model and ask the model to predict those masked-out pieces. In the RTD objective, random tokens in the text are replaced and the model has to determine which tokens were replaced. To make this harder for the model, the replacements are meant to be plausible alternatives rather than just random words. CodeBERT actually used an n-gram based model to generate these alternatives, whereas the ELECTRA paper used a small BERT-based model.
(Figure from the ELECTRA paper)
Instead of applying these training objectives to natural language only, CodeBERT used code and docstrings. This allowed the CodeBERT model to learn a useful representation of code that can be transferred to other tasks.
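If you'd like to see the MLM objective in action, Microsoft also published the MLM-only checkpoint microsoft/codebert-base-mlm. As a quick illustrative example (separate from the finetuning pipeline we'll build below), we can ask it to fill in a masked token of a Java snippet:

from transformers import pipeline

# CodeBERT's MLM checkpoint predicts masked-out tokens; RoBERTa-style models
# use <mask> as the mask token
fill_mask = pipeline('fill-mask', model='microsoft/codebert-base-mlm')

for pred in fill_mask('public static void <mask>(String[] args) { }'):
    print(pred['sequence'], pred['score'])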
Alright, with that quick background out of the way, let's get into actually finetuning our model!
! nvidia-smi
Thu Jan 14 20:43:12 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.27.04 Driver Version: 418.67 CUDA Version: 10.1 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 42C P8 11W / 70W | 0MiB / 15079MiB | 0% Default |
| | | ERR! |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
Data
First we'll install the necessary packages and download our data!
# Download and install the necessary dependencies
! pip install -q torch==1.4.0 -f https://download.pytorch.org/whl/cu101/torch_stable.html
! pip install -q transformers==3.5.0 fast-trees
! git clone -q https://github.com/microsoft/CodeXGLUE.git
# Download the CodeSearchNet Challenge dataset for the Java programming language
! wget -q https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip
! unzip -qq java.zip
ERROR: torchvision 0.8.1+cu101 has requirement torch==1.7.0, but you'll have torch 1.4.0 which is incompatible.
Next, let's read in our data. Since these models take a long time to train, we will only select a small subset of it.
import pandas as pd

from pathlib import Path
from typing import List, Optional

# Code from CodeSearchNet Challenge: https://github.com/github/CodeSearchNet/blob/master/notebooks/ExploreData.ipynb
def jsonl_list_to_dataframe(file_list, columns=['code', 'docstring']):
    """Load a list of jsonl.gz files into a pandas DataFrame."""
    return pd.concat([pd.read_json(f, orient='records', compression='gzip', lines=True)[columns]
                      for f in file_list], sort=False)

def get_dfs(path: Path) -> List[pd.DataFrame]:
    """Grabs the different data splits and converts them into dataframes"""
    dfs = []
    for split in ["train", "valid", "test"]:
        files = sorted((path/split).glob("**/*.gz"))
        df = jsonl_list_to_dataframe(files).rename(columns={'code': 'mthd', 'docstring': 'cmt'})
        dfs.append(df)
    return dfs

path = Path('.')
df_trn, df_val, df_tst = get_dfs(path/"java/final/jsonl")

sample = 0.01
df_trn = df_trn.sample(frac=sample)
df_val = df_val.sample(frac=sample)
df_tst = df_tst.sample(frac=sample)
len(df_trn), len(df_val), len(df_tst)
(4545, 153, 269)
Let's see how the data looks. As shown, the data is in a good format, with one column containing the methods (the input to the model) and the other containing the comments (the output of the model).
df_trn.head()
|       | mthd                                               | cmt                                                |
|-------|----------------------------------------------------|----------------------------------------------------|
| 5360  | @Override\n public GetLexiconResult getLexi...     | <p>\nReturns the content of the specified pron...  |
| 9365  | public static void checkJavaInternalAccess(ILo...  | Prints warning to given {@link ILogger} if Haz...  |
| 10145 | private IAtom createAtom(Element element) {\n ...  | Create a new atom for the provided symbol. The...  |
| 9008  | public void marshall(Scte20PlusEmbeddedDestina...  | Marshall the given parameter object.               |
| 24498 | @Override\n public void prefetchToken(final F...   | /*\nGets hadoop tokens for a user to run mapre...  |
Data Cleaning
Now that we have the data, let's clean it! First, we'll remove any non-ASCII characters to simplify the problem, so the model only has to think about generating English comments.
# From https://stackoverflow.com/a/27084708/5768407
def is_ascii(s):
    '''
    Determines if the given string contains only ascii characters

    :param s: the string to check
    :returns: whether or not the given string contains only ascii characters
    '''
    try:
        s.encode(encoding='utf-8').decode('ascii')
    except UnicodeDecodeError:
        return False
    else:
        return True

df_trn = df_trn[df_trn['mthd'].apply(lambda x: is_ascii(x))]
df_val = df_val[df_val['mthd'].apply(lambda x: is_ascii(x))]
df_tst = df_tst[df_tst['mthd'].apply(lambda x: is_ascii(x))]

df_trn = df_trn[df_trn['cmt'].apply(lambda x: is_ascii(x))]
df_val = df_val[df_val['cmt'].apply(lambda x: is_ascii(x))]
df_tst = df_tst[df_tst['cmt'].apply(lambda x: is_ascii(x))]
len(df_trn), len(df_val), len(df_tst)
(4402, 141, 264)
Next, we'll remove any outdated comments by checking whether the JavaDoc's parameter list differs from the method's parameter list. This will also remove pairs where the docstring doesn't actually document the parameters, which probably means the pair is of poor quality (you should always properly document your code :) ).
import re

from fast_trees.core import FastParser

parser = FastParser('java')

def get_cmt_params(cmt: str) -> List[str]:
    '''
    Grabs the parameter identifier names from a JavaDoc comment

    :param cmt: the comment to extract the parameter identifier names from
    :returns: an array of the parameter identifier names found in the given comment
    '''
    params = re.findall(r'@param+\s+\w+', cmt)
    param_names = []
    for param in params:
        param_names.append(param.split()[1])
    return param_names

def is_outdated(mthd: str, cmt: str, parser: FastParser) -> bool:
    '''
    Determines if a given method and comment are outdated by checking if the
    method's parameter identifier names match the comment's

    :param mthd: the method to compare against its corresponding comment
    :param cmt: the comment to compare against its corresponding method
    :param parser: parser for easily getting the parameter identifier names from a given method
    :returns: whether or not a given comment is outdated compared to its corresponding method
    '''
    try:
        mthd_params = parser.get_params(mthd)
    except:
        return False

    cmt_params = get_cmt_params(cmt)
    return mthd_params != cmt_params

df_trn = df_trn[~df_trn.apply(lambda x: is_outdated(x.mthd, x.cmt, parser), axis=1)]
df_val = df_val[~df_val.apply(lambda x: is_outdated(x.mthd, x.cmt, parser), axis=1)]
df_tst = df_tst[~df_tst.apply(lambda x: is_outdated(x.mthd, x.cmt, parser), axis=1)]
len(df_trn), len(df_val), len(df_tst)
Downloading repo https://github.com/tree-sitter/tree-sitter-java to /usr/local/lib/python3.6/dist-packages/fast_trees/tree-sitter-java.
(4402, 141, 264)
Now we'll add in the additional pairs of code snippets/inline comments.
P.S. One thing to note with adding these pairs is that each inline comment will appear twice in the dataset: first inside the method it came from, and second as the target for its code snippet. This is only a problem for the training set, since it allows the model to cheat by simply remembering the inline comment from the method it came from. However, in my testing I found this not to be an issue, and the model seems to work well despite it. Just thought ya should know :).
from tqdm.auto import tqdm

def get_inline_pairs(mthd):
    '''
    Get all pairs of inline comments and corresponding code snippets

    :param mthd: the method to retrieve the pairs of comments and corresponding code snippets from
    :returns: all pairs of comments and corresponding code snippets
    '''
    pairs = [[]]

    comment = False
    bracket = False
    indent_lvl = -1
    lines = mthd.split("\n")
    for line in lines:
        if "//" in line and not bracket and not "://" in line:
            pairs[-1].append(line)
            if '\t' in line:
                indent_lvl = line.count('\t')
            else:
                indent_lvl = line.split("//")[0].count(' ')
            comment = True
            bracket = False
        elif comment:
            if '{' in line and not bracket:
                bracket = True
                pairs[-1].append(line)
            elif '}' in line:
                line_indent = -1
                if '\t' in line:
                    line_indent = line.count('\t')
                else:
                    line_indent = line.split("//")[0].count(' ')
                if indent_lvl == line_indent:
                    pairs[-1].append(line)
                    if not bracket:
                        pairs.append([])
                        comment = False
                    bracket = False
            elif line.isspace() or line == '' and not bracket:
                pairs.append([])
                comment = False
            else:
                pairs[-1].append(line)

    # Convert pairs into proper format of (code snippet, inline comment) dataframe
    code_snippets = []
    comments = []
    for pair in pairs:
        if pair and len(pair) < 5:
            code = []
            comment = []
            skip = False
            for line in pair:
                if "TODO" in line: break
                if "//" in line:
                    comment.append(line.replace('//', ''))
                else:
                    code.append(line)
            if len(code) > 1 and len(comment) > 0:
                code_snippets.append('\n'.join(code))
                comments.append('\n'.join(comment))

    pairs = pd.DataFrame(zip(code_snippets, comments), columns=["mthd", "cmt"])
    return pairs

def add_inline(df: pd.DataFrame) -> pd.DataFrame:
    '''
    Helper function to go through all methods in a given dataframe and add all
    pairs of inline comments and corresponding code snippets

    :param df: the dataframe to retrieve and add all pairs of inline comments and corresponding code snippets to
    :returns: a new dataframe with the newly added pairs of inline comments and corresponding code snippets
    '''
    new_df = df[df['mthd'].str.contains("//")]
    all_pairs = []
    for mthd in tqdm(new_df.mthd.values):
        pairs = get_inline_pairs(mthd)
        all_pairs.append(pairs)

    df_pairs = pd.concat([pairs for pairs in all_pairs])
    return pd.concat([df, df_pairs])

df_trn = add_inline(df_trn)
df_val = add_inline(df_val)
df_tst = add_inline(df_tst)
len(df_trn), len(df_val), len(df_tst)
(4584, 150, 271)
We'll also remove pairs where the code is shorter than the comment; a sketch of this filter is shown below. This is because I found that in these cases the comments contain a bunch of extra information the model won't have access to, such as how the method is used by other methods in the software system.
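The cell for this filter isn't shown above, but a minimal sketch might look like the following (I'm assuming a simple character-length comparison, and the helper name is hypothetical; comparing token counts would work just as well):

def code_shorter_than_cmt(row: pd.Series) -> bool:
    # Hypothetical helper: flag pairs whose code is shorter than their comment
    return len(row.mthd) < len(row.cmt)

df_trn = df_trn[~df_trn.apply(code_shorter_than_cmt, axis=1)]
df_val = df_val[~df_val.apply(code_shorter_than_cmt, axis=1)]
df_tst = df_tst[~df_tst.apply(code_shorter_than_cmt, axis=1)]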
Next, we'll remove any examples whose comments contain the special <code> HTML tag, since these also tend to contain extra information that the model has little hope of generating.
def has_code(cmt: str) -> bool:
    '''
    Determine if the given comment contains the HTML <code> tag

    :param cmt: the comment to check whether it contains the HTML <code> tag
    :returns: whether or not the given comment contains the HTML <code> tag
    '''
    if '<code>' in cmt:
        return True
    else:
        return False

df_trn = df_trn[~df_trn['cmt'].apply(lambda x: has_code(x))]
df_val = df_val[~df_val['cmt'].apply(lambda x: has_code(x))]
df_tst = df_tst[~df_tst['cmt'].apply(lambda x: has_code(x))]
len(df_trn), len(df_val), len(df_tst)
(3580, 104, 221)
Lastly, we're gonna remove the JavaDoc tags from the comments, keeping only the description, since that's really all we care about. The other pieces of information can usually be autogenerated or may require external knowledge to document.
def remove_jdocs(df: pd.DataFrame) -> pd.DataFrame:
    '''
    Remove the JavaDocs leaving only the description of the comment

    :param df: the pandas dataframe to remove the JavaDocs from
    :returns: a new pandas dataframe with the JavaDocs removed
    '''
    methods = []
    comments = []
    for i, row in tqdm(list(df.iterrows())):
        comment = row["cmt"]
        # Remove {} text in comments from https://stackoverflow.com/questions/14596884/remove-text-between-and-in-python/14598135
        comment = re.sub(r"([\{\[]).*?([\)\}])", '', comment)

        cleaned = []
        for line in comment.split('\n'):
            if "@" in line: break
            cleaned.append(line)
        comments.append('\n'.join(cleaned))
        methods.append(row["mthd"])

    new_df = pd.DataFrame(zip(methods, comments), columns=["mthd", "cmt"])
    return new_df

df_trn = remove_jdocs(df_trn);
df_val = remove_jdocs(df_val);
df_tst = remove_jdocs(df_tst);
Almost there! In this step, we'll remove any HTML tags from the comments so the model doesn't have to also learn HTML. Bless those that do...
def clean_html(cmt: str) -> str:
    '''
    Remove any HTML tags from a given comment

    :param cmt: the comment to remove any HTML tags from
    :returns: the comment with any HTML tags removed
    '''
    result = re.sub(r"<.?span[^>]*>|<.?code[^>]*>|<.?p[^>]*>|<.?hr[^>]*>|<.?h[1-3][^>]*>|<.?a[^>]*>|<.?b[^>]*>|<.?blockquote[^>]*>|<.?del[^>]*>|<.?dd[^>]*>|<.?dl[^>]*>|<.?dt[^>]*>|<.?em[^>]*>|<.?i[^>]*>|<.?img[^>]*>|<.?kbd[^>]*>|<.?li[^>]*>|<.?ol[^>]*>|<.?pre[^>]*>|<.?s[^>]*>|<.?sup[^>]*>|<.?sub[^>]*>|<.?strong[^>]*>|<.?strike[^>]*>|<.?ul[^>]*>|<.?br[^>]*>", "", cmt)
    return result

df_trn.cmt = df_trn.cmt.apply(clean_html)
df_val.cmt = df_val.cmt.apply(clean_html)
df_tst.cmt = df_tst.cmt.apply(clean_html)
FINALLY!! We'll make everything lower case, remove extra whitespace, remove empty comments, and remove duplicates; a sketch of this step is shown below.
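The original cell for this step isn't shown, so here's a minimal sketch of what these cleanups might look like (assumed implementation; the lowercasing is also why all of the code and comments you'll see later in this post are lower case):

def clean_up(df: pd.DataFrame) -> pd.DataFrame:
    # Lowercase and collapse runs of whitespace in both columns
    df = df.copy()
    df.mthd = df.mthd.str.lower().str.split().str.join(' ')
    df.cmt = df.cmt.str.lower().str.split().str.join(' ')
    # Drop empty comments, then duplicate pairs
    df = df[df.cmt.str.len() > 0]
    return df.drop_duplicates()

df_trn = clean_up(df_trn)
df_val = clean_up(df_val)
df_tst = clean_up(df_tst)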
As good Data Scientists, we will also explore our data to uncover any secrets. Data can be sneaky like that :).
import numpy as np

from collections import Counter
from statistics import mean, median, stdev

from transformers import AutoTokenizer

def get_counter(df: pd.DataFrame, tokenizer: AutoTokenizer, col: str) -> Counter:
    '''
    Get the counts for each token in a given pandas dataframe column

    :param df: the pandas dataframe to get the counts of tokens from
    :param tokenizer: the tokenizer to use for tokenizing the rows in the pandas dataframe
    :param col: the column to grab rows from when tokenizing
    :returns: the counts of each token in the given pandas dataframe column
    '''
    toks = []
    for i, row in df.iterrows():
        toks.extend(tokenizer.tokenize(row[col]))
    cnt = Counter()
    for tok in toks:
        cnt[tok] += 1
    return cnt

tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
mthd_cnt = get_counter(df_trn, tokenizer, 'mthd')
cmt_cnt = get_counter(df_trn, tokenizer, 'cmt')

mthd_lens = df_trn.mthd.apply(lambda x: len(tokenizer.tokenize(x))).values
cmt_lens = df_trn.cmt.apply(lambda x: len(tokenizer.tokenize(x))).values
max_mthd_len = int(np.quantile(mthd_lens, 0.95))
max_cmt_len = int(np.quantile(cmt_lens, 0.95))
import matplotlib.pyplot as plt

def plot_counts(counts: Counter, top_k: Optional[int] = 30):
    '''
    Plot a bar chart of the most common tokens

    :param counts: the counts of each token
    :param top_k: the number of tokens to display in the plot
    '''
    labels, values = zip(*counts.most_common()[:top_k])

    indexes = np.arange(len(labels))
    width = 1

    plt.figure(num=None, figsize=(22, 4), dpi=60, facecolor='w', edgecolor='k')
    plt.bar(indexes, values, width)
    plt.xticks(indexes + width * 0.5, labels)
    plt.show()
Let's look at the most common tokens in our methods and comments.
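Calling our plotting helper on both counters (the rendered charts are omitted here):

plot_counts(mthd_cnt, top_k=30)
plot_counts(cmt_cnt, top_k=30)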
def plot_hist(lens: List[int], n_bins: Optional[int] = 50):
    '''
    Plot a histogram of the given number of tokens in a column

    :param lens: the number of tokens in a column
    :param n_bins: the number of bins to sort the number of tokens into
    '''
    n, bins, patches = plt.hist(lens, n_bins, facecolor='blue', alpha=0.9)
    plt.show()
Now, let's look at the distribution of method and comment lengths.
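And the same for the length distributions (charts again omitted):

plot_hist(mthd_lens)
plot_hist(cmt_lens)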
Using this new information about the length distributions, we can remove outliers by filtering out methods and comments whose lengths fall outside the 95th percentile (chosen for completely arbitrary reasons)!
def filter_len(row: pd.Series, tokenizer: AutoTokenizer, mthd_len: int, cmt_len: int) -> bool:
    '''
    Determine if a given pandas dataframe row has a method or comment that has more tokens than the max length

    :param row: the row to check if it has a method or comment that is too long
    :param tokenizer: the tokenizer to tokenize a method or comment
    :param mthd_len: the max number of tokens a method can have
    :param cmt_len: the max number of tokens a comment can have
    :returns: whether or not the given row has a method or comment with more tokens than the max length
    '''
    return len(tokenizer.tokenize(row.mthd)) < mthd_len and len(tokenizer.tokenize(row.cmt)) < cmt_len

df_trn = df_trn[df_trn.apply(lambda row: filter_len(row, tokenizer, max_mthd_len, max_cmt_len), axis=1)]
df_val = df_val[df_val.apply(lambda row: filter_len(row, tokenizer, max_mthd_len, max_cmt_len), axis=1)]
df_tst = df_tst[df_tst.apply(lambda row: filter_len(row, tokenizer, max_mthd_len, max_cmt_len), axis=1)]
len(df_trn), len(df_val), len(df_tst)
(2809, 88, 193)
max_mthd_len,max_cmt_len
(559, 48)
We could do a lot more exploring of our data; the above was just the bare minimum. As an exercise, I suggest you explore the data on your own using whatever means necessary!
Training
Now that we have our data processed and in a format we like, let's go ahead and start training! To accomplish this, we will be using code from the awesome CodeXGLUE repository. This repository is similar to NLP's GLUE benchmark: a ton of code-related benchmarks standardized and put into one place for the community to use! They have a bunch of interesting tasks, and I highly suggest looking through their repo if you're interested in other code-related problems.
cd ./CodeXGLUE/Code-Text/code-to-text/code
/content/CodeXGLUE/Code-Text/code-to-text/code
Okay, I lied, sorry :(. One last processing step is required, which is to write our data out in the structure that the CodeXGLUE Code-Text benchmark expects; a sketch of that (plus kicking off training) is below.
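The export cell isn't reproduced here, so below is a minimal sketch of what it needs to produce. CodeXGLUE's run.py reads jsonl files in which each record carries code_tokens and docstring_tokens fields; I'm using plain whitespace tokenization, which matches how we cleaned the data (treat the helper name as hypothetical):

import json

def to_jsonl(df: pd.DataFrame, path: str):
    # Write one JSON record per method/comment pair in the layout run.py expects
    with open(path, 'w') as out:
        for _, row in df.iterrows():
            out.write(json.dumps({'code_tokens': row.mthd.split(),
                                  'docstring_tokens': row.cmt.split()}) + '\n')

to_jsonl(df_trn, 'train.jsonl')
to_jsonl(df_val, 'valid.jsonl')
to_jsonl(df_tst, 'test.jsonl')

From there, finetuning is kicked off with the benchmark's run.py script; the flags below mirror the ones in the CodeXGLUE Code-Text README, so tweak the batch sizes and lengths to fit your GPU:

! python run.py --do_train --do_eval --model_type roberta \
    --model_name_or_path microsoft/codebert-base \
    --train_filename train.jsonl --dev_filename valid.jsonl \
    --output_dir model --max_source_length 256 --max_target_length 128 \
    --beam_size 10 --train_batch_size 16 --eval_batch_size 16 \
    --learning_rate 5e-5 --num_train_epochs 10

Once training finishes, we can generate comments for the test set and compare them with the originals. Here are some samples: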
Code: public static byte[] decode(final string s) { int delta = s.endswith("==") ? 2 : s.endswith("=") ? 1 : 0; byte[] buffer = new byte[s.length() * bytes_per_unencoded_block / bytes_per_encoded_block - delta]; int mask = 0xff; int pos = 0; for (int i = 0; i < s.length(); i += bytes_per_encoded_block) { int c0 = decode_table[s.charat(i)]; int c1 = decode_table[s.charat(i + 1)]; buffer[pos++] = (byte) (((c0 << 2) | (c1 >> 4)) & mask); if (pos >= buffer.length) { return buffer; } int c2 = decode_table[s.charat(i + 2)]; buffer[pos++] = (byte) (((c1 << 4) | (c2 >> 2)) & mask); if (pos >= buffer.length) { return buffer; } int c3 = decode_table[s.charat(i + 3)]; buffer[pos++] = (byte) (((c2 << 6) | c3) & mask); } return buffer; }
Original Comment: decodes the given base64-encoded string.
Generated Comment: decode encode a string representation of string
========================================
Code: private void extractapklib( artifact apklibartifact ) throws mojoexecutionexception { getunpackedlibhelper().extractapklib( apklibartifact ); // copy the assets to the the combinedassets folder. // add the apklib source and resource to the compile. // nb apklib sources are added to compilesourceroot because we may need to compile against them. // this means the apklib classes will be compiled into target/classes and packaged with this build. copyfolder( getunpackedlibassetsfolder( apklibartifact ), combinedassets ); final file apklibsourcefolder = getunpackedapklibsourcefolder( apklibartifact ); final list<string> resourceexclusions = arrays.aslist( "**/*.java", "**/*.aidl" ); projecthelper.addresource( project, apklibsourcefolder.getabsolutepath(), null, resourceexclusions ); project.addcompilesourceroot( apklibsourcefolder.getabsolutepath() ); }
Original Comment: extracts apklib and adds the assets and apklib sources and resources to the build.
Generated Comment: extracts the cp libraries from the given source library.
========================================
Code: static <t> t[] copy(object[] source, int from, int to, t[] arrayoftype) { t[] result = newarray(arrayoftype, to - from); system.arraycopy(source, from, result, 0, to - from); return result; }
Original Comment: equivalent to arrays.copyofrange(source, from, to, arrayoftype.getclass()).
Generated Comment: creates a new object from the array.
========================================
Code: private static runtimedelegate finddelegate() { runtimedelegate result=null; try { result=createruntimedelegatefromspi(); if(result==null) { result=createruntimedelegatefromconfigurationfile(); } if(result==null) { string delegateclassname = system.getproperty(application_engine_spi_property); if(delegateclassname!=null) { result=createruntimedelegateforclassname(delegateclassname); } } } catch (exception ex) { logger.warn("could not find application engine",ex); } return result; }
Original Comment: obtain an instance using the method described in }.
Generated Comment: /* package
========================================
Code: public static string getcategory(string eventsrcname) { if (eventsrcname == null) { return null; } int end = eventsrcname.lastindexof('.'); eventsrcname = eventsrcname.substring(0, end); if (checkstyle_package.equals(eventsrcname)) { return "misc"; } else if (!eventsrcname.startswith(checkstyle_package)) { return "extension"; } return eventsrcname.substring(eventsrcname.lastindexof('.') + 1); }
Original Comment: get the rule category from an audit event source name.
Generated Comment: returns the contents of the event name.
========================================
Code: private collection<artifact> getserverdependencies(final string servertype, final expressionevaluator expressionevaluator) throws componentconfigurationexception { try { final mavenproject project = (mavenproject) expressionevaluator.evaluate("${project}"); final string localrepo = (string) expressionevaluator.evaluate("${settings.localrepository}"); final artifactrepository localrepository = repositorysystem.createlocalrepository(new file(localrepo)); final repositoryrequest repositoryrequest = new defaultrepositoryrequest(); repositoryrequest.setremoterepositories(project.getremoteartifactrepositories()); repositoryrequest.setlocalrepository(localrepository); final artifactresolutionrequest request = new artifactresolutionrequest(repositoryrequest); request.setartifact(getserverartifact(servertype)); request.setresolvetransitively(true); final artifactresolutionresult result = repositorysystem.resolve(request); if (result.issuccess()) { return result.getartifacts(); } boolean first = true; final stringbuilder builder = new stringbuilder("cannot resolve dependencies: ["); for (final artifact artifact : result.getmissingartifacts()) { if (!first) { builder.append(','); } else { first = false; } builder.append(artifact.getgroupid()); builder.append(':'); builder.append(artifact.getartifactid()); builder.append(':'); builder.append(artifact.getversion()); } builder.append("]"); throw new componentconfigurationexception(builder.tostring()); } catch (final expressionevaluationexception e) { throw new componentconfigurationexception("error evaluating expression", e); } catch (final invalidrepositoryexception e) { throw new componentconfigurationexception("error resolving local repository", e); } }
Original Comment: resolve the ldap server type artifact and its dependencies.
Generated Comment: gets the repositories from the repository.
========================================
Code: private void frame4() { long currenttime = system.currenttimemillis(); // xxx: lots of dummy value // record trade information in trade table. // insert into trade (t_id, t_dts, t_st_id, t_tt_id, t_is_cash, // t_s_symb, t_qty, t_bid_price, t_ca_id, t_exec_name, t_trade_price, // t_chrg, t_comm, t_tax, t_lifo) values (...) string sql = string.format("insert into trade (t_id, t_dts, t_st_id, t_tt_id, " + "t_is_cash, t_s_symb, t_qty, t_bid_price, t_ca_id, t_exec_name, " + "t_trade_price, t_chrg, t_comm, t_tax, t_lifo) values (%d, %d, '%s', " + "'%s', %d, '%s', %d, %f, %d, '%s', %f, %f, %f, %f, %d)", paramhelper.gettradeid(), currenttime, statusid, paramhelper.gettradetypeid(), 1, paramhelper.getsymbol(), paramhelper.gettradeqty(), marketprice, paramhelper.getacctid(), "exec_name", paramhelper.gettradeprice(), 0.0, 0.0, 0.0, 1); executeupdate(sql); // todo: implement this (not in the simplified version) // record pending trade information in trade_request table // if this trade is a limit trade // insert into trade_request (tr_t_id, tr_tt_id, tr_s_symb, tr_qty, // tr_bid_price, tr_b_id) values (...) // record trade information in trade_history table // insert into trade_history (th_t_id, th_dts, th_st_id) values (...) sql = string.format("insert into trade_history (th_t_id, th_dts, th_st_id) values " + "(%d, %d, '%s')", paramhelper.gettradeid(), currenttime, statusid); executeupdate(sql); }
Original Comment: record the trade request by making all related updates
Generated Comment: this method is used to create the database.
========================================
Code: protected string getquery() { final stringbuilder ret = new stringbuilder(); try { final string clazzname; if (efapssystemconfiguration.get().containsattributevalue("org.efaps.kernel.index.querybuilder")) { clazzname = efapssystemconfiguration.get().getattributevalue("org.efaps.kernel.index.querybuilder"); } else { clazzname = "org.efaps.esjp.admin.index.lucencequerybuilder"; } final class<?> clazz = class.forname(clazzname, false, efapsclassloader.getinstance()); final object obj = clazz.newinstance(); final method method = clazz.getmethod("getquery4dimvalues", string.class, list.class, list.class); final object newquery = method.invoke(obj, getcurrentquery(), getincluded(), getexcluded()); ret.append(newquery); } catch (final efapsexception | classnotfoundexception | instantiationexception | illegalaccessexception | nosuchmethodexception | securityexception | illegalargumentexception | invocationtargetexception e) { indexsearch.log.error("catched", e); ret.append(getcurrentquery()); } return ret.tostring(); }
Original Comment: gets the query.
Generated Comment: get the query instance.
========================================
Code: private languagedata findlanguage(final string locale) { for (final languagedata languagedata : languagedatadao.getall()) { if (languagedata.getlanguagecode().equalsignorecase(locale)) { return languagedata; } } return null; }
Original Comment: find language.
Generated Comment: gets the specified locale.
========================================
Code: private standardintrospectionresponse callstandardintrospection(string parameters) { if (parameters == null) { // authlete returns different error codes for null and an empty string. // 'null' is regarded as a caller's error. an empty string is regarded // as a client application's error. parameters = ""; } // create a request for authlete's /api/auth/introspection/standard api. standardintrospectionrequest request = new standardintrospectionrequest() .setparameters(parameters); try { // call authlete's /api/auth/introspection/standard api. return mapi.standardintrospection(request); } catch (authleteapiexception e) { // the api call failed. throw apifailure("/api/auth/introspection/standard", e); } }
Original Comment: call authlete's api.
Generated Comment: returns a set of authentication object.
========================================
The model seems to be doing a good job, but if you play with it some more, you'll realize it's mostly taking the name of the method and using that to guide the comment. This makes sense, but it probably isn't learning much more than this association, at least with this small model. Let's explore a bit more by looking at the validation examples the model fails hardest on, ranked by its loss (printed after each generated comment).
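The cell that produces this ranking isn't shown either, so here's a hedged sketch of the idea. I'm using a generic transformers EncoderDecoderModel purely as a stand-in; in practice you'd load the finetuned Seq2Seq checkpoint that run.py saved and score it the same way:

import torch
from transformers import EncoderDecoderModel

# Stand-in model for illustration only; swap in the finetuned checkpoint
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    'microsoft/codebert-base', 'microsoft/codebert-base')
model.eval()

losses = []
with torch.no_grad():
    for _, row in df_val.iterrows():
        src = tokenizer(row.mthd, return_tensors='pt', truncation=True)
        tgt = tokenizer(row.cmt, return_tensors='pt', truncation=True)
        out = model(input_ids=src.input_ids, attention_mask=src.attention_mask,
                    decoder_input_ids=tgt.input_ids, labels=tgt.input_ids,
                    return_dict=True)
        losses.append(out.loss.item())

# The highest-loss examples are the ones the model struggles with most
df_val.assign(loss=losses).sort_values('loss', ascending=False).head(10)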
Code: private collection<artifact> getserverdependencies(final string servertype, final expressionevaluator expressionevaluator) throws componentconfigurationexception { try { final mavenproject project = (mavenproject) expressionevaluator.evaluate("${project}"); final string localrepo = (string) expressionevaluator.evaluate("${settings.localrepository}"); final artifactrepository localrepository = repositorysystem.createlocalrepository(new file(localrepo)); final repositoryrequest repositoryrequest = new defaultrepositoryrequest(); repositoryrequest.setremoterepositories(project.getremoteartifactrepositories()); repositoryrequest.setlocalrepository(localrepository); final artifactresolutionrequest request = new artifactresolutionrequest(repositoryrequest); request.setartifact(getserverartifact(servertype)); request.setresolvetransitively(true); final artifactresolutionresult result = repositorysystem.resolve(request); if (result.issuccess()) { return result.getartifacts(); } boolean first = true; final stringbuilder builder = new stringbuilder("cannot resolve dependencies: ["); for (final artifact artifact : result.getmissingartifacts()) { if (!first) { builder.append(','); } else { first = false; } builder.append(artifact.getgroupid()); builder.append(':'); builder.append(artifact.getartifactid()); builder.append(':'); builder.append(artifact.getversion()); } builder.append("]"); throw new componentconfigurationexception(builder.tostring()); } catch (final expressionevaluationexception e) { throw new componentconfigurationexception("error evaluating expression", e); } catch (final invalidrepositoryexception e) { throw new componentconfigurationexception("error resolving local repository", e); } }
Original Comment: resolve the ldap server type artifact and its dependencies.
Generated Comment: gets the repository from the repository.
24.875783920288086
========================================
Code: public static byte[] decode(final string s) { int delta = s.endswith("==") ? 2 : s.endswith("=") ? 1 : 0; byte[] buffer = new byte[s.length() * bytes_per_unencoded_block / bytes_per_encoded_block - delta]; int mask = 0xff; int pos = 0; for (int i = 0; i < s.length(); i += bytes_per_encoded_block) { int c0 = decode_table[s.charat(i)]; int c1 = decode_table[s.charat(i + 1)]; buffer[pos++] = (byte) (((c0 << 2) | (c1 >> 4)) & mask); if (pos >= buffer.length) { return buffer; } int c2 = decode_table[s.charat(i + 2)]; buffer[pos++] = (byte) (((c1 << 4) | (c2 >> 2)) & mask); if (pos >= buffer.length) { return buffer; } int c3 = decode_table[s.charat(i + 3)]; buffer[pos++] = (byte) (((c2 << 6) | c3) & mask); } return buffer; }
Original Comment: decodes the given base64-encoded string.
Generated Comment: encodes a string value from a string.
24.304515838623047
========================================
Code: @override public void init(configurationvalueprovider... configurationvalueproviders) { if (configurationvalueproviders != null) { for (configurationproperty property : getcontainer().properties.values()) { property.init(configurationvalueproviders); } } }
Original Comment: override default values for properties with the given configurationproviders.
Generated Comment: configures all the options in the given configuration.
24.276317596435547
========================================
Code: private static boolean validatepart(string part, boolean isfinalpart) { // these tests could be collapsed into one big boolean expression, but // they have been left as independent tests for clarity. if (part.length() < 1 || part.length() > max_domain_part_length) { return false; } /* * gwt claims to support java.lang.character's char-classification methods, but it actually only * works for ascii. so for now, assume any non-ascii characters are valid. the only place this * seems to be documented is here: * http://osdir.com/ml/googlewebtoolkitcontributors/2010-03/msg00178.html * * <p>ascii characters in the part are expected to be valid per rfc 1035, with underscore also * being allowed due to widespread practice. */ string asciichars = charmatcher.ascii().retainfrom(part); if (!part_char_matcher.matchesallof(asciichars)) { return false; } // no initial or final dashes or underscores. if (dash_matcher.matches(part.charat(0)) || dash_matcher.matches(part.charat(part.length() - 1))) { return false; } /* * note that we allow (in contravention of a strict interpretation of the relevant rfcs) domain * parts other than the last may begin with a digit (for example, "3com.com"). it's important to * disallow an initial digit in the last part; it's the only thing that stops an ipv4 numeric * address like 127.0.0.1 from looking like a valid domain name. */ if (isfinalpart && charmatcher.digit().matches(part.charat(0))) { return false; } return true; }
Original Comment: helper method for }. validates that one part of a domain name is valid.
Generated Comment: parses a string representation of the given string.
24.256574630737305
========================================
Code: private void extractapklib( artifact apklibartifact ) throws mojoexecutionexception { getunpackedlibhelper().extractapklib( apklibartifact ); // copy the assets to the the combinedassets folder. // add the apklib source and resource to the compile. // nb apklib sources are added to compilesourceroot because we may need to compile against them. // this means the apklib classes will be compiled into target/classes and packaged with this build. copyfolder( getunpackedlibassetsfolder( apklibartifact ), combinedassets ); final file apklibsourcefolder = getunpackedapklibsourcefolder( apklibartifact ); final list<string> resourceexclusions = arrays.aslist( "**/*.java", "**/*.aidl" ); projecthelper.addresource( project, apklibsourcefolder.getabsolutepath(), null, resourceexclusions ); project.addcompilesourceroot( apklibsourcefolder.getabsolutepath() ); }
Original Comment: extracts apklib and adds the assets and apklib sources and resources to the build.
Generated Comment: extracts compiled from the cp compiler.
23.989707946777344
========================================
Code: public void adddefaultheader(final string name, final string value) { validate.notempty(name, "header name cannot be empty"); validate.notnull(value, "header value cannot be null, use an empty string instead"); this.checkconfigurable(); this.defaultheaders.put(name, value); }
Original Comment: adds a default header to be added to every stub http response.
Generated Comment: adds the headers to the headers.
23.846609115600586
========================================
Code: public static schema getschema(final file xsd, final errorhandler errorhandler) throws saxexception { // create a new instance for an xsd-aware schemafactory final schemafactory schemafactory = schemafactory .newinstance(http_www_w3_org_2001_xml_schema); // set the errorhandler implementation. schemafactory.seterrorhandler(errorhandler); // get the custom xsd schema that describes // the required format for my xml files. return schemafactory.newschema(xsd); }
Original Comment: gets the schema.
Generated Comment: creates a xml object from the given namespace.
23.77509880065918
========================================
Code: @override protected formatwriter createwriter(final outputstream outputstream, final formatlogger logger) { try { return new dsmlformatwriter(outputstream); } catch (final ioexception e) { logger.logerror("could not create and intialise the dsml writer", e); } return null; }
Original Comment: create the ldap writer that will dump ldap entries to a dsml file.
Generated Comment: creates a new writer as a xml file.
23.688125610351562
========================================
Code: @override public volatileimage createcompatiblevolatileimage(int width, int height, imagecapabilities caps, int transparency) throws awtexception { if (img == null) { img = new bufferedimage(1, 1, bufferedimage.type_int_argb); gc = img.creategraphics().getdeviceconfiguration(); } return gc.createcompatiblevolatileimage(width, height, caps, transparency); }
Original Comment: returns a volatile image. this method is a workaround for a classcastexception that occurs on macosx when exporting a swing ui that uses the nimbus look and feel to svg.
Generated Comment: create a new image from a new image.
23.60519790649414
========================================
Code: private static void printstacktrace(printstream out, throwable err) { out.println(err.getclass().getname() + ": " + err.getmessage()); for (stacktraceelement ste : err.getstacktrace()) { out.println("\tat " + ste.tostring()); } if (err.getcause() != null) { out.print("caused by: "); printstacktrace(out, err.getcause()); } }
Original Comment: print a complete stack trace. this differs from throwable.printstacktrace() in that it always prints all of the trace.
Generated Comment: print out a message.
23.529924392700195
========================================
What's Next?
If you'd like to see how you can integrate this code comment summarizer model into the popular VSCode IDE, check out my video that goes over just that!