Question
I came across this page.
1) I would like to get the sentence-level embedding (the embedding given by the [CLS] token) after fine-tuning is done. How can I do it?
2) I also noticed that the code on that page takes a long time to return results on the test data. Why is that? Training the model took less time than getting the test predictions. I did not use the following blocks of code from that page:
test_InputExamples = test.apply(lambda x: bert.run_classifier.InputExample(guid=None,
                                                                           text_a = x[DATA_COLUMN],
                                                                           text_b = None,
                                                                           label = x[LABEL_COLUMN]), axis = 1)
test_features = bert.run_classifier.convert_examples_to_features(test_InputExamples, label_list, MAX_SEQ_LENGTH, tokenizer)
test_input_fn = run_classifier.input_fn_builder(
    features=test_features,
    seq_length=MAX_SEQ_LENGTH,
    is_training=False,
    drop_remainder=False)
estimator.evaluate(input_fn=test_input_fn, steps=None)
Instead, I just used the function below on my entire test data:
def getPrediction(in_sentences):
    labels = ["Negative", "Positive"]
    # label = 0 is just a dummy label; it is not used for prediction
    input_examples = [run_classifier.InputExample(guid="", text_a = x, text_b = None, label = 0) for x in in_sentences]
    input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
    predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
    predictions = estimator.predict(predict_input_fn)
    return [(sentence, prediction['probabilities'], labels[prediction['labels']]) for sentence, prediction in zip(in_sentences, predictions)]
3) How can I get the probability of a prediction? Is there a way to use the Keras predict method?
Update 1
Question 2 update: could you test the getPrediction function on 20,000 training examples? It takes much longer for me, even longer than training the model on 20,000 examples.
Answer 1:
1) From the BERT documentation, the output dictionary contains:
pooled_output: pooled output of the entire sequence, with shape [batch_size, hidden_size].
sequence_output: representations of every token in the input sequence, with shape [batch_size, max_sequence_length, hidden_size].
I've added the pooled_output vector, which corresponds to the [CLS] token: it is the final hidden state of [CLS] passed through BERT's pooler (a dense layer with tanh activation).
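For intuition, here is a minimal pure-Python sketch of what the pooler does: take the hidden state of the first ([CLS]) token from sequence_output and apply a dense layer followed by tanh. The tiny hidden size and the weight and bias values are illustrative assumptions, not the real model's parameters.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def pool_cls(sequence_output: List[Matrix], weights: Matrix, bias: Vector) -> Matrix:
    """Map [batch, seq_len, hidden] -> [batch, hidden] like a BERT-style pooler:
    tanh(W @ h_cls + b), where h_cls is the hidden state of the first token."""
    pooled = []
    for sequence in sequence_output:
        h_cls = sequence[0]  # the [CLS] token sits at position 0
        row = [
            math.tanh(sum(w * h for w, h in zip(w_row, h_cls)) + b)
            for w_row, b in zip(weights, bias)
        ]
        pooled.append(row)
    return pooled


# Illustrative values only: batch=1, seq_len=3, hidden=2 (BERT-base uses 768).
seq_out = [[[0.5, -1.0], [0.2, 0.3], [0.0, 0.1]]]
W = [[1.0, 0.0], [0.0, 2.0]]  # hypothetical pooler weights
b = [0.0, 0.5]                # hypothetical pooler bias
print(pool_cls(seq_out, W, b))
```

The tanh is also why every value in the pooled_output arrays in this answer lies between -1 and 1.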
3) The model gives you log probabilities. Apply softmax (or, equivalently, exponentiate the log probabilities) to get normal probabilities.
All that is left is to make the model report them. I have kept the log probabilities in the output, but they are no longer necessary.
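A quick pure-Python check of this, using the log probabilities from the example output in this answer: exponentiating them recovers the reported probabilities, and so does applying softmax to them, because log_softmax only subtracts a constant (log of the sum of exp of the logits) and softmax is unchanged by such a shift.

```python
import math
from typing import List


def softmax(values: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Log probabilities reported for 'That movie was absolutely awful' in this answer.
log_probs = [-4.0148855e-03, -5.5197663e+00]

probs_via_exp = [math.exp(lp) for lp in log_probs]
probs_via_softmax = softmax(log_probs)
print(probs_via_exp)      # approximately [0.995993, 0.004007]
print(probs_via_softmax)  # the same values, up to float rounding
```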
See the code changes:
def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):
    """Creates a classification model."""
    bert_module = hub.Module(
        BERT_MODEL_HUB,
        trainable=True)
    bert_inputs = dict(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids)
    bert_outputs = bert_module(
        inputs=bert_inputs,
        signature="tokens",
        as_dict=True)

    # Use "pooled_output" for classification tasks on an entire sentence.
    # Use "sequence_output" for token-level output.
    output_layer = bert_outputs["pooled_output"]
    # Keep the sentence embedding before dropout is applied.
    pooled_output = output_layer

    hidden_size = output_layer.shape[-1].value

    # Create our own classification layer to fine-tune on our data.
    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    with tf.variable_scope("loss"):
        # Dropout helps prevent overfitting; skip it when predicting
        # so that predictions are deterministic.
        if not is_predicting:
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        probs = tf.nn.softmax(logits, axis=-1)

        # Convert labels into one-hot encoding
        one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

        predicted_labels = tf.squeeze(tf.argmax(log_probs, axis=-1, output_type=tf.int32))
        # If we're predicting, we want the predicted labels and the probabilities.
        if is_predicting:
            return (predicted_labels, log_probs, probs, pooled_output)

        # If we're training/evaluating, compute the loss between predicted and actual labels
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, predicted_labels, log_probs, probs, pooled_output)
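The math of the classification head that create_model builds on top of pooled_output can be sketched in pure Python for a single example: logits = W·x + b, then log_softmax, argmax for the predicted label, and the cross-entropy loss, the negative sum of one_hot * log_probs. The weights and inputs below are made-up toy values, and dropout is left out for simplicity.

```python
import math
from typing import List, Tuple


def classification_head(x: List[float], weights: List[List[float]], bias: List[float],
                        label: int) -> Tuple[List[float], int, float]:
    """Return (log_probs, predicted_label, loss) for one example.

    weights has shape [num_labels, hidden_size], like output_weights in
    create_model, so logits[i] = dot(weights[i], x) + bias[i] (transpose_b=True).
    """
    logits = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]
    # log_softmax(z)_i = z_i - log(sum_j exp(z_j)), computed stably.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]
    predicted = max(range(len(log_probs)), key=lambda i: log_probs[i])
    one_hot = [1.0 if i == label else 0.0 for i in range(len(log_probs))]
    loss = -sum(o * lp for o, lp in zip(one_hot, log_probs))
    return log_probs, predicted, loss


# Toy values: hidden_size=3, num_labels=2 (hypothetical, not trained weights).
x = [0.5, -0.2, 0.1]
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
b = [0.0, 0.0]
log_probs, predicted, loss = classification_head(x, W, b, label=0)
print(predicted, loss)
```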
Now, in model_fn_builder(), add support for these values:
# Change the unpacking in both places (the train/eval branch also receives the loss first)
(predicted_labels, log_probs, probs, pooled_output) = create_model(
    is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

# Return a dictionary of all the values you want
predictions = {
    'log_probabilities': log_probs,
    'probabilities': probs,
    'labels': predicted_labels,
    'pooled_output': pooled_output
}
Adjust getPrediction() accordingly, and in the end your predictions will look like this:
('That movie was absolutely awful',
 array([0.99599314, 0.00400678], dtype=float32),          <= probabilities
 array([-4.0148855e-03, -5.5197663e+00], dtype=float32),  <= log probabilities, same as before
 'Negative',                                              <= label
 array([ 0.9181199 ,  0.7763732 ,  0.9999883 , -0.93533266, -0.9841384 ,
         0.78126144, -0.9918988 , -0.18764131,  0.9981035 ,  0.99999994,
         0.900716  , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
         0.9501321 ,  0.75836045,  0.49151263, -0.7886792 ,  0.97505844,
        -0.8931161 , -1.        ,  0.9318583 , -0.60531116, -0.8644371 ,
        ...
The last array is the 768-dimensional [CLS]-based vector (the sentence embedding).
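Given results in the order shown above (sentence, probabilities, log probabilities, label, embedding), collecting the sentence embeddings is plain Python. A minimal sketch, assuming that tuple order; the function name is hypothetical and the 3-value vectors are made-up stand-ins for the real 768-value arrays.

```python
from typing import Dict, List, Sequence, Tuple

# (sentence, probabilities, log_probabilities, label, pooled_output)
Prediction = Tuple[str, Sequence[float], Sequence[float], str, Sequence[float]]


def embeddings_by_sentence(predictions: List[Prediction]) -> Dict[str, List[float]]:
    """Map each input sentence to its pooled_output embedding.

    Assumes the tuple order above; a repeated sentence keeps its last embedding.
    """
    return {sentence: list(pooled) for sentence, _probs, _log_probs, _label, pooled in predictions}


# Hypothetical stand-ins: real pooled_output arrays have 768 values, not 3.
fake_predictions = [
    ("That movie was absolutely awful", [0.9, 0.1], [-0.1, -2.3], "Negative", [0.1, 0.2, 0.3]),
    ("The acting was a bit lacking", [0.8, 0.2], [-0.2, -1.6], "Negative", [0.4, 0.5, 0.6]),
]
embeddings = embeddings_by_sentence(fake_predictions)
print(len(embeddings), embeddings["The acting was a bit lacking"])
```

Calling list() on each vector works for NumPy arrays as well as plain lists.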
Regarding 2): on my end, training took about 5 minutes and testing about 40 seconds, which is very reasonable.
UPDATE
For 20k samples, training took 12:48 and testing took 2:07 (m:ss).
For 10k samples, the timings were 8:40 and 1:07, respectively.
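Those timings can be turned into throughput numbers with a little arithmetic. A sketch, assuming the times are in m:ss format:

```python
def to_seconds(mm_ss: str) -> int:
    """Convert an 'm:ss' duration string to seconds."""
    minutes, seconds = mm_ss.split(":")
    return int(minutes) * 60 + int(seconds)


# Timings reported above: (number of samples, training time, test time).
runs = [(20_000, "12:48", "2:07"), (10_000, "8:40", "1:07")]
for n, train, test in runs:
    train_s, test_s = to_seconds(train), to_seconds(test)
    print(f"{n} samples: train {train_s} s, test {test_s} s, "
          f"test throughput {n / test_s:.1f} samples/s, "
          f"training takes {train_s / test_s:.1f}x as long as testing")
```

On these numbers, prediction handles roughly 150 samples per second, and testing is about 6 to 8 times faster than training.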
Answer 2:
Sure, here are the rest of the changes:
# model_fn_builder actually creates our model function
# using the passed parameters for num_labels, learning_rate, etc.
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps):
    """Returns `model_fn` closure for TPUEstimator."""
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)

        # TRAIN and EVAL
        if not is_predicting:
            (loss, predicted_labels, log_probs, probs, pooled_output) = create_model(
                is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

            train_op = bert.optimization.create_optimizer(
                loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

            # Calculate evaluation metrics.
            def metric_fn(label_ids, predicted_labels):
                accuracy = tf.metrics.accuracy(label_ids, predicted_labels)
                f1_score = tf.contrib.metrics.f1_score(label_ids, predicted_labels)
                auc = tf.metrics.auc(label_ids, predicted_labels)
                recall = tf.metrics.recall(label_ids, predicted_labels)
                precision = tf.metrics.precision(label_ids, predicted_labels)
                true_pos = tf.metrics.true_positives(label_ids, predicted_labels)
                true_neg = tf.metrics.true_negatives(label_ids, predicted_labels)
                false_pos = tf.metrics.false_positives(label_ids, predicted_labels)
                false_neg = tf.metrics.false_negatives(label_ids, predicted_labels)
                return {
                    "eval_accuracy": accuracy,
                    "f1_score": f1_score,
                    "auc": auc,
                    "precision": precision,
                    "recall": recall,
                    "true_positives": true_pos,
                    "true_negatives": true_neg,
                    "false_positives": false_pos,
                    "false_negatives": false_neg
                }

            eval_metrics = metric_fn(label_ids, predicted_labels)

            if mode == tf.estimator.ModeKeys.TRAIN:
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  train_op=train_op)
            else:
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  eval_metric_ops=eval_metrics)
        else:
            (predicted_labels, log_probs, probs, pooled_output) = create_model(
                is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

            predictions = {
                'log_probabilities': log_probs,
                'probabilities': probs,
                'labels': predicted_labels,
                'pooled_output': pooled_output
            }
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    # Return the actual model function in the closure
    return model_fn
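metric_fn relies on TensorFlow's streaming tf.metrics ops. For intuition, here is a non-streaming pure-Python sketch of what the confusion-matrix counts, precision, recall, F1 and accuracy mean for binary labels (1 = positive). It is an illustration only: it leaves out AUC and does not reproduce the streaming behavior of the TensorFlow ops.

```python
from typing import Dict, Sequence


def binary_metrics(label_ids: Sequence[int], predicted_labels: Sequence[int]) -> Dict[str, float]:
    """Compute confusion-matrix counts and derived metrics, treating 1 as positive."""
    pairs = list(zip(label_ids, predicted_labels))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "eval_accuracy": (tp + tn) / len(pairs) if pairs else 0.0,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
    }


print(binary_metrics([1, 0, 1, 1, 0], [1, 0, 0, 1, 1]))
```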
def getPrediction(in_sentences):
    labels = ["Negative", "Positive"]
    # label = 0 is just a dummy label; it is not used for prediction
    input_examples = [run_classifier.InputExample(guid="", text_a = x, text_b = None, label = 0) for x in in_sentences]
    input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
    predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
    predictions = estimator.predict(predict_input_fn)
    return [(sentence, prediction['probabilities'], prediction['log_probabilities'], labels[prediction['labels']], prediction['pooled_output'])
            for sentence, prediction in zip(in_sentences, predictions)]
Here is the beginning of the output (the rest is cut off because of the 30K-character limit on answers):
[('That movie was absolutely awful',
  array([0.99599314, 0.00400678], dtype=float32),
  array([-4.0148855e-03, -5.5197663e+00], dtype=float32),
  'Negative',
  array([ 0.9181199 ,  0.7763732 ,  0.9999883 , -0.93533266, -0.9841384 ,
          ...
          0.9986462 ,  0.92254627, -0.6406982 ], dtype=float32)),
 ('The acting was a bit lacking',
  array([0.9921152 , 0.00788479], dtype=float32),
  array([-0.00791603, -4.842819  ], dtype=float32),
  'Negative',
  array([ 0.67417824,  0.8235167 ,  0.99999565, -0.8565971 , -0.99499583,
          ...
Source: https://stackoverflow.com/questions/60767089/bert-get-sentence-level-embedding-after-fine-tuning