source
loss_func
loss_func (logits, labels)
Calculates the per-token cross entropy loss between the model's output and the labels.
logits: the model's output logits
labels: the labels to calculate the cross entropy loss against
# test loss function
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer(
    ["Hello, my dog is cute", "Hello, my dog is cute"], return_tensors="pt"
)
outputs = model(**inputs)
logits = outputs.logits
labels = inputs.input_ids
loss_func(logits, labels)
tensor([[2.3432, 3.7964, 6.6038, 1.7265, 5.4809],
[2.3432, 3.7964, 6.6038, 1.7265, 5.4809]], grad_fn=<ViewBackward0>)
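Note that loss_func returns the unreduced, per-position losses rather than a single scalar. A minimal sketch of how such a per-token loss can be computed (shift-by-one next-token prediction with an unreduced cross entropy; this is an illustration under that assumption, not necessarily the library's exact implementation, and per_token_cross_entropy is a hypothetical name):

import torch.nn.functional as F

def per_token_cross_entropy(logits, labels):
    # tokens < n predict token n: drop the last logit and the first label
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # reduction="none" keeps one loss value per predicted position
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction="none",
    )
    return loss.view(shift_labels.size())

For the batch above ("Hello, my dog is cute" is six GPT-2 tokens) this yields a (2, 5) tensor, matching the shape of the output shown.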
source
get_counts
get_counts (model, tokenizer, batch, semantic_column:str,
            stop_word_column:str, return_distributions:bool)
Returns the loss and token counts for the given batch.
model: the model to use for predictions
tokenizer: the tokenizer to use for encoding
batch: the batch to use for predictions
semantic_column (str): the column to use for semantic predictions
stop_word_column (str): the column to use for stop word predictions
return_distributions (bool): whether to return the distributions
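get_counts is mainly used internally by perplexed (below), and its exact return structure is not documented here; the pattern its description names, counts for losses and tokens, amounts to bucketing per-token losses into collections.Counter objects keyed by token or by semantic category. A rough sketch of that pattern with made-up values (the losses, tokens, and categories are purely illustrative, not real model output):

from collections import Counter

losses = [1.50, 2.27, 3.93, 1.50]      # hypothetical per-token losses
tokens = [" the", ".", " =", " the"]   # decoded tokens
categories = ["< N/A >", "< N/A >", "<module -> comment>", "< N/A >"]  # e.g. AST labels

loss_cnt, token_cnt = Counter(), Counter()
for loss, tok, category in zip(losses, tokens, categories):
    key = tok  # or `category` when a semantic column is supplied
    loss_cnt[key] += loss  # total loss per key
    token_cnt[key] += 1    # occurrences per key

avg_loss = {k: loss_cnt[k] / token_cnt[k] for k in loss_cnt}
print(avg_loss)  # {' the': 1.5, '.': 2.27, ' =': 3.93}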
source
perplexed
perplexed (model:transformers.modeling_utils.PreTrainedModel,
           dataset:datasets.arrow_dataset.Dataset,
           tokenizer:transformers.tokenization_utils.PreTrainedTokenizer=None,
           column:str='text', semantic_column:str=None,
           stop_word_column:str=None, n_gram:int=1, batch_size:int=1,
           num_proc:int=2, device:str='cuda',
           collate_fn=<function default_data_collator>, pass_row:bool=False,
           return_tokens:bool=False, return_distributions:bool=False,
           compute_perplexity:bool=True)
Calculate the perplexity of a model on a dataset.
model (PreTrainedModel): The model to calculate the perplexity of.
dataset (Dataset): The dataset to calculate the perplexity on.
tokenizer (PreTrainedTokenizer, default None): The tokenizer to use to tokenize the dataset. If not provided, the tokenizer associated with the model will be used.
column (str, default 'text'): The column of the dataset to calculate the perplexity on.
semantic_column (str, default None): The column of the dataset to calculate the semantic perplexity on, such as NER tags.
stop_word_column (str, default None): The column of the dataset that contains boolean values indicating whether the token is a stop word.
n_gram (int, default 1): The n-gram size to calculate the perplexity on.
batch_size (int, default 1): The batch size to use when calculating the perplexity.
num_proc (int, default 2): The number of processes to use when tokenizing the dataset.
device (str, default 'cuda'): The device to use when calculating the perplexity.
collate_fn (function, default default_data_collator): The collate function to use when calculating the perplexity.
pass_row (bool, default False): Whether to pass the full dataset row to the tokenizer.
return_tokens (bool, default False): Whether to return the token counts along with the perplexity.
return_distributions (bool, default False): Whether to return the perplexity distributions instead of the aggregated perplexity.
compute_perplexity (bool, default True): Whether to compute the perplexity. If False, the cross entropy will be returned instead.
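A note on compute_perplexity: with the default True the returned counter holds per-token perplexities, and with False it holds per-token cross entropies. The two are related by exponentiation, as the '<|endoftext|>' values reported in the examples below illustrate (a quick check using those reported values, not library code):

import math

# cross entropy reported for '<|endoftext|>' in the wikitext example below
cross_entropy = 10.327683209001043
math.exp(cross_entropy)  # ~30567.2, the perplexity reported for the same token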
Perplexity per token
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M" )
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M" )
model.to(device)
dataset = load_dataset("wikitext" , "wikitext-2-raw-v1" , split= "test" ).select(range (50 ))
# filter out empty strings
dataset = dataset.filter (lambda x: len (x["text" ]) > 0 )
perplexity_cnt, token_cnt = perplexed(
model,
dataset,
tokenizer= tokenizer,
column= "text" ,
batch_size= 1 ,
device= device,
num_proc= 1 ,
return_tokens= True ,
)
assert len (perplexity_cnt) == len (token_cnt)
assert perplexity_cnt.keys() == token_cnt.keys()
cross_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=tokenizer,
    column="text",
    batch_size=1,
    device=device,
    num_proc=1,
    return_tokens=True,
    compute_perplexity=False,
)
assert len(cross_cnt) == len(token_cnt)
assert cross_cnt.keys() == token_cnt.keys()
cross_cnt.most_common(10)
[(' wired', 17.92612648010254),
(' shatter', 16.32363510131836),
(' Career', 15.21772575378418),
(' Early', 14.70047664642334),
(' Television', 14.659582138061523),
(' Daylight', 14.56997299194336),
(' unrecogn', 14.364179611206055),
(' @', 14.307954322208058),
(' Chou', 14.180266380310059),
(' advisers', 13.927596092224121)]
cross_cnt.most_common()[-10:]
[('mers', 0.03539723251014948),
('mith', 0.018193976022303104),
('t', 0.016906073316931725),
(' than', 0.009314415045082569),
('jiang', 0.005416479427367449),
('ian', 0.004262291360646486),
('aire', 0.002999095479026437),
('el', 0.0017088347813114524),
('ights', 0.001490435330197215),
('sworth', 0.0009158230968751013)]
# cross entropy of the most common tokens
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {cross_cnt[token]}")
'<|endoftext|>': 10.327683209001043
' the': 1.5023754525995046
',': 2.799564078589466
'.': 2.2654987903962653
' "': 2.2530801612883806
' in': 2.0132113315057065
' of': 1.2379778898500193
' a': 2.107695746828209
' =': 3.9336307379530697
' and': 1.6605487003922463
Perplexity per semantic type
The following cells calculate the perplexity per semantic type, using a tokenizer (code_tokenizers' CodeTokenizer) that aligns the AST of a program with the BPE tokens of a language model's tokenizer.
!pip install -U code_tokenizers
!download_grammars
from code_tokenizers.core import CodeTokenizer
from transformers import default_data_collator

def code_collator(batch):
    # "merged_ast" holds per-token AST labels (strings), which
    # default_data_collator cannot turn into tensors, so pop them out,
    # collate the remaining tensor fields, and re-attach the labels.
    merged_ast = []
    for b in batch:
        merged_ast.append(b.pop("merged_ast"))
    batch = default_data_collator(batch)
    batch["merged_ast"] = merged_ast
    return batch
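A quick sanity check of the collator on a toy batch (the field values here are made up purely for illustration; in the real dataset they come from CodeTokenizer):

toy_batch = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1],
     "merged_ast": ["<module -> comment>", "< N/A >", "< N/A >"]},
    {"input_ids": [4, 5, 6], "attention_mask": [1, 1, 1],
     "merged_ast": ["< N/A >", "<dotted_name -> identifier>", "< N/A >"]},
]
collated = code_collator(toy_batch)
collated["input_ids"].shape  # torch.Size([2, 3]) -- tensors from default_data_collator
collated["merged_ast"]       # the per-example AST labels, passed through untouched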
model_name = "codeparrot/codeparrot-small"
py_tokenizer = CodeTokenizer.from_pretrained(model_name, "python")
py_tokenizer.tokenizer.pad_token = py_tokenizer.tokenizer.eos_token
py_tokenizer.pad_token = py_tokenizer.tokenizer.pad_token
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)
dataset = load_dataset("codeparrot/codeparrot-clean-valid", split="train").select(
    range(15)
)
cross_cnt, token_cnt = perplexed(
    model,
    dataset,
    tokenizer=py_tokenizer,
    column="content",
    semantic_column="merged_ast",
    stop_word_column="is_builtins",
    batch_size=1,
    num_proc=1,
    device=device,
    collate_fn=code_collator,
    return_tokens=True,
    compute_perplexity=False,
)
assert len(cross_cnt) == len(token_cnt)
assert cross_cnt.keys() == token_cnt.keys()
cross_cnt.most_common(10)
[('reports', 15.318881034851074),
('Double', 15.236268043518066),
('BLANK', 15.137480735778809),
('148', 14.469829559326172),
('BD', 13.819499969482422),
('year', 13.65689468383789),
(' filesystem', 13.625283241271973),
('CO', 13.59871768951416),
('Pure', 13.172009468078613),
('customize', 13.098344802856445)]
token_cnt.most_common(10)
[('<|endoftext|>', 3951),
('<module -> comment>', 1479),
('< N/A >', 1123),
('<attribute -> identifier>', 1019),
('<argument_list -> string>', 728),
('<expression_statement -> string>', 677),
('.', 608),
('<dotted_name -> identifier>', 608),
('_', 434),
('\n', 391)]
# perplexity of the most common tokens
tokens = [token for token, _ in token_cnt.most_common(10)]
for token in tokens:
    print(f"'{token}': {perplexity_cnt[token]}")
'<|endoftext|>': 30567.21875
'<module -> comment>': 0
'< N/A >': 0
'<attribute -> identifier>': 0
'<argument_list -> string>': 0
'<expression_statement -> string>': 0
'.': 9.635930061340332
'<dotted_name -> identifier>': 0
'_': 0
'
': 3.0456223487854004