BERT get sentence level embedding after fine tuning

南楼画角 submitted on 2020-04-13 06:34:08

Question


I came across this page

1) I would like to get the sentence-level embedding (the embedding given by the [CLS] token) after fine-tuning is done. How can I do that?

2) I also noticed that the code on that page takes a long time to return results on the test data. Why is that? Training the model took less time than getting the test predictions. Note that I didn't use the following blocks of code from that page:

test_InputExamples = test.apply(lambda x: bert.run_classifier.InputExample(guid=None, 
                                                                       text_a = x[DATA_COLUMN], 
                                                                       text_b = None, 
                                                                       label = x[LABEL_COLUMN]), axis = 1)

test_features = bert.run_classifier.convert_examples_to_features(test_InputExamples, label_list, MAX_SEQ_LENGTH, tokenizer)

test_input_fn = run_classifier.input_fn_builder(
        features=test_features,
        seq_length=MAX_SEQ_LENGTH,
        is_training=False,
        drop_remainder=False)

estimator.evaluate(input_fn=test_input_fn, steps=None)

Instead, I just ran the following function on my entire test data:

def getPrediction(in_sentences):
  labels = ["Negative", "Positive"]
  input_examples = [run_classifier.InputExample(guid="", text_a = x, text_b = None, label = 0) for x in in_sentences] # guid is unused; label=0 is just a dummy label
  input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
  predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
  predictions = estimator.predict(predict_input_fn)
  return [(sentence, prediction['probabilities'], labels[prediction['labels']]) for sentence, prediction in zip(in_sentences, predictions)]

3) How can I get the probability of a prediction? Is there a way to use the Keras predict method?

Update 1

Question 2 update: could you test the getPrediction function on 20,000 training examples? It takes much longer for me, even longer than it took to train the model on 20,000 examples.


Answer 1:


1) From the BERT documentation:

The output dictionary contains:

pooled_output: pooled output of the entire sequence, with shape [batch_size, hidden_size].
sequence_output: representations of every token in the input sequence, with shape [batch_size, max_sequence_length, hidden_size].

I've added the pooled_output vector, which corresponds to the [CLS] token: it is the final hidden state of [CLS] passed through a dense layer with tanh activation.
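For reference, BERT's pooler takes the final hidden state at position 0 of sequence_output (the [CLS] token) and passes it through a dense layer with tanh activation, which is why every component of pooled_output lies in [-1, 1]. Below is a minimal NumPy sketch of that computation; the toy sizes and random weights are assumptions for illustration, whereas the real pooler uses the weights stored in the BERT checkpoint.

```python
import numpy as np

def bert_pooler(sequence_output, dense_kernel, dense_bias):
    """Sketch of BERT's pooler: [CLS] hidden state -> dense layer -> tanh."""
    cls_hidden = sequence_output[:, 0, :]  # [batch_size, hidden_size]
    return np.tanh(cls_hidden @ dense_kernel + dense_bias)

# Toy sizes (hypothetical): batch of 2, sequence length 4, hidden size 8.
# BERT-Base uses hidden_size = 768 and pretrained pooler weights instead.
rng = np.random.default_rng(0)
sequence_output = rng.normal(size=(2, 4, 8)).astype(np.float32)
dense_kernel = rng.normal(size=(8, 8)).astype(np.float32)
dense_bias = np.zeros(8, dtype=np.float32)

pooled_output = bert_pooler(sequence_output, dense_kernel, dense_bias)
print(pooled_output.shape)  # (2, 8)
```

If you want the raw [CLS] hidden state (without the pooler's dense + tanh), you can slice position 0 of bert_outputs["sequence_output"] in the same way.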

3) The model returns log probabilities. Apply softmax to the logits (or, equivalently, exponentiate the log probabilities) to get ordinary probabilities.

Now all that's left is to make the model report them. I have kept the log probabilities in the output, but they are no longer necessary.
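The relationship between logits, log probabilities and probabilities can be checked in NumPy. This sketch uses numerically stable softmax/log-softmax helpers (written here for illustration, not taken from the notebook) and the log probabilities printed later in this answer for 'That movie was absolutely awful'.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

# Log probabilities printed for 'That movie was absolutely awful'
log_probs = np.array([-4.0148855e-03, -5.5197663e+00], dtype=np.float32)
probs = np.exp(log_probs)  # matches the 'probabilities' array up to rounding

# For arbitrary logits z, exp(log_softmax(z)) equals softmax(z)
logits = np.array([2.0, -1.0, 0.5])
recovered = np.exp(log_softmax(logits))
```

Because the log probabilities are already normalized, exponentiating them is enough; applying softmax to them again gives the same result.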

See the code changes:

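import tensorflow as tf        # TF 1.x API (tf.get_variable, tf.contrib, ...)
import tensorflow_hub as hub
# BERT_MODEL_HUB is the TF Hub module URL defined earlier in the notebook.

def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):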
  """Creates a classification model."""

  bert_module = hub.Module(
      BERT_MODEL_HUB,
      trainable=True)
  bert_inputs = dict(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids)
  bert_outputs = bert_module(
      inputs=bert_inputs,
      signature="tokens",
      as_dict=True)

  # Use "pooled_output" for classification tasks on an entire sentence.
  # Use "sequence_output" for token-level output.
  output_layer = bert_outputs["pooled_output"]

  # Keep the sentence embedding ([CLS] pooled vector) before dropout is applied
  pooled_output = output_layer

  hidden_size = output_layer.shape[-1].value

  # Create our own classification layer to fine-tune on this data.
  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):

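    # Dropout helps prevent overfitting; skip it at prediction time so that
    # predicted probabilities are deterministic.
    if not is_predicting:
      output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)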

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)

    # Convert labels into one-hot encoding
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    predicted_labels = tf.squeeze(tf.argmax(log_probs, axis=-1, output_type=tf.int32))
    # If we're predicting, we want predicted labels and the probabilities.
    if is_predicting:
      return (predicted_labels, log_probs, probs, pooled_output)

    # If we're train/eval, compute loss between predicted and actual label
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, predicted_labels, log_probs, probs, pooled_output)
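The classification head that create_model() builds on top of pooled_output can be mirrored in NumPy to check the shapes and the loss formula. This is only a sketch: dropout is omitted (as at prediction time), classification_head is a hypothetical helper, and the toy weights are stand-ins for the trained output_weights and output_bias.

```python
import numpy as np

def classification_head(pooled_output, output_weights, output_bias, labels=None):
    """NumPy mirror of the head in create_model (dropout omitted).

    pooled_output: [batch, hidden]; output_weights: [num_labels, hidden].
    """
    logits = pooled_output @ output_weights.T + output_bias  # [batch, num_labels]
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    probs = np.exp(log_probs)
    predicted_labels = log_probs.argmax(axis=-1)
    if labels is None:  # prediction mode
        return predicted_labels, log_probs, probs
    # Cross-entropy with one-hot labels, averaged over the batch
    one_hot = np.eye(output_weights.shape[0])[labels]
    per_example_loss = -(one_hot * log_probs).sum(axis=-1)
    return per_example_loss.mean(), predicted_labels, log_probs, probs

# Toy example (hypothetical values): hidden size 3, two labels
pooled = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
weights = np.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
bias = np.zeros(2)
loss, predicted, log_probs, probs = classification_head(
    pooled, weights, bias, labels=np.array([0, 1]))
```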

Now add support for those values in model_fn_builder():

  # Update both create_model() calls; the TRAIN/EVAL call also unpacks loss first
  (predicted_labels, log_probs, probs, pooled_output) = create_model(
    is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

  # return dictionary of all the values you wanted
  predictions = {
      'log_probabilities': log_probs,
      'probabilities': probs,
      'labels': predicted_labels,
      'pooled_output': pooled_output
  }

Adjust getPrediction() accordingly, and your predictions will look like this:

('That movie was absolutely awful',
  array([0.99599314, 0.00400678], dtype=float32),  <= Probability
  array([-4.0148855e-03, -5.5197663e+00], dtype=float32), <= Log probability, same as previously
  'Negative', <= Label
  array([ 0.9181199 ,  0.7763732 ,  0.9999883 , -0.93533266, -0.9841384 ,
          0.78126144, -0.9918988 , -0.18764131,  0.9981035 ,  0.99999994,
          0.900716  , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
          0.9501321 ,  0.75836045,  0.49151263, -0.7886792 ,  0.97505844,
         -0.8931161 , -1.        ,  0.9318583 , -0.60531116, -0.8644371 ,
        ...
        and this is the 768-dimensional [CLS] vector (sentence embedding).

Regarding 2): on my end, training took about 5 minutes and testing about 40 seconds, which is very reasonable.

UPDATE

For 20k samples, training took 12:48 and testing took 2:07 (minutes:seconds).

For 10k samples, the timings are 8:40 and 1:07, respectively.
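To compare those numbers directly, they can be converted into throughput. This sketch assumes the timings are minutes:seconds; to_seconds is a hypothetical helper. Training also runs a backward pass and, depending on the notebook's epoch setting (not shown here), may pass over the data more than once, so prediction is expected to process far more samples per second.

```python
def to_seconds(mm_ss):
    """Convert a 'minutes:seconds' string such as '12:48' to seconds."""
    minutes, seconds = mm_ss.split(":")
    return int(minutes) * 60 + int(seconds)

# Timings reported above, read as minutes:seconds (an assumption)
timings = {20000: ("12:48", "2:07"), 10000: ("8:40", "1:07")}

throughput = {}
for n, (train, test) in timings.items():
    t_train, t_test = to_seconds(train), to_seconds(test)
    # Dataset samples handled per second of wall-clock time
    throughput[n] = (n / t_train, n / t_test)
    print(f"{n} samples: train {n / t_train:.1f}/s, test {n / t_test:.1f}/s")
```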




Answer 2:


Sure, here are the rest of the changes:

# model_fn_builder actually creates our model function
# using the passed parameters for num_labels, learning_rate, etc.
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps):
  """Returns `model_fn` closure for TPUEstimator."""
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)

    # TRAIN and EVAL
    if not is_predicting:

      (loss, predicted_labels, log_probs, probs, pooled_output) = create_model(
        is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

      train_op = bert.optimization.create_optimizer(
          loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

      # Calculate evaluation metrics. 
      def metric_fn(label_ids, predicted_labels):
        accuracy = tf.metrics.accuracy(label_ids, predicted_labels)
        f1_score = tf.contrib.metrics.f1_score(
            label_ids,
            predicted_labels)
        auc = tf.metrics.auc(
            label_ids,
            predicted_labels)
        recall = tf.metrics.recall(
            label_ids,
            predicted_labels)
        precision = tf.metrics.precision(
            label_ids,
            predicted_labels) 
        true_pos = tf.metrics.true_positives(
            label_ids,
            predicted_labels)
        true_neg = tf.metrics.true_negatives(
            label_ids,
            predicted_labels)   
        false_pos = tf.metrics.false_positives(
            label_ids,
            predicted_labels)  
        false_neg = tf.metrics.false_negatives(
            label_ids,
            predicted_labels)
        return {
            "eval_accuracy": accuracy,
            "f1_score": f1_score,
            "auc": auc,
            "precision": precision,
            "recall": recall,
            "true_positives": true_pos,
            "true_negatives": true_neg,
            "false_positives": false_pos,
            "false_negatives": false_neg
        }

      eval_metrics = metric_fn(label_ids, predicted_labels)

      if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(mode=mode,
          loss=loss,
          train_op=train_op)
      else:
        return tf.estimator.EstimatorSpec(mode=mode,
          loss=loss,
          eval_metric_ops=eval_metrics)
    else:
      (predicted_labels, log_probs, probs, pooled_output) = create_model(
        is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

      predictions = {
          'log_probabilities': log_probs,
          'probabilities': probs,
          'labels': predicted_labels,
          'pooled_output': pooled_output
      }
      return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  # Return the actual model function in the closure
  return model_fn
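metric_fn relies on TF 1.x streaming metrics that accumulate over evaluation batches. As a sanity check, the same binary quantities can be computed in one shot with NumPy; binary_metrics is a hypothetical helper, not part of the notebook, and it assumes 0/1 labels with 1 as the positive class. Note that metric_fn passes predicted_labels rather than probabilities to tf.metrics.auc, so that AUC is computed from thresholded predictions.

```python
import numpy as np

def binary_metrics(label_ids, predicted_labels):
    """One-shot counterparts of the counts and ratios returned by metric_fn."""
    y = np.asarray(label_ids).astype(bool)
    p = np.asarray(predicted_labels).astype(bool)
    tp = int(np.sum(p & y))
    tn = int(np.sum(~p & ~y))
    fp = int(np.sum(p & ~y))
    fn = int(np.sum(~p & y))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "eval_accuracy": (tp + tn) / y.size,
        "f1_score": f1,
        "precision": precision,
        "recall": recall,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
    }

# Toy labels and predictions (hypothetical)
metrics = binary_metrics([1, 0, 1, 1, 0], [1, 0, 0, 1, 1])
```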


def getPrediction(in_sentences):
  labels = ["Negative", "Positive"]
  input_examples = [run_classifier.InputExample(guid="", text_a = x, text_b = None, label = 0) for x in in_sentences] # guid is unused; label=0 is just a dummy label
  input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
  predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
  predictions = estimator.predict(predict_input_fn)
  return [(sentence, prediction['probabilities'], prediction['log_probabilities'], labels[prediction['labels']], prediction['pooled_output']) for sentence, prediction in zip(in_sentences, predictions)]

And here is the first output (the others are cut off because of the 30K-character limit on answers):

[('That movie was absolutely awful',
  array([0.99599314, 0.00400678], dtype=float32),
  array([-4.0148855e-03, -5.5197663e+00], dtype=float32),
  'Negative',
  array([ 0.9181199 ,  0.7763732 ,  0.9999883 , -0.93533266, -0.9841384 ,
          0.78126144, -0.9918988 , -0.18764131,  0.9981035 ,  0.99999994,
          0.900716  , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
          0.9501321 ,  0.75836045,  0.49151263, -0.7886792 ,  0.97505844,
         -0.8931161 , -1.        ,  0.9318583 , -0.60531116, -0.8644371 ,
         -0.9999866 ,  0.5820049 ,  0.3257555 , -0.81900954, -0.8326617 ,
          0.87788117, -0.7791749 ,  0.11098853,  0.67873836,  0.9999771 ,
          0.9833652 , -0.8420576 ,  0.83076835,  0.37272754,  0.8667175 ,
          0.792386  , -0.82003427, -0.9999999 , -0.9382297 , -0.9713775 ,
          0.55752313,  1.        , -0.72632766, -0.4752956 , -0.9999852 ,
         -0.99974227, -0.9998661 , -0.3094257 , -0.93023825, -0.72663504,
          0.92974335, -0.8601105 , -0.8113003 ,  0.7660112 ,  0.9313508 ,
          0.21427669, -0.45660907,  0.99970686,  0.56852764, -0.9997675 ,
         -0.9999096 ,  0.8247045 ,  0.7205424 ,  0.47192624, -0.7523966 ,
         -0.9588541 , -0.48866934,  0.9809366 , -0.07110611, -0.99886   ,
         -0.63922834, -0.68144   , -1.        ,  0.8531816 ,  0.26078308,
         -0.99898577, -0.99968046,  0.6711601 ,  0.99857473, -0.99990964,
          1.        , -0.97127694, -0.10644457,  0.46306637, -0.32486317,
         -0.68167734,  0.43291137, -0.996574  ,  0.05164305,  0.9897354 ,
          0.93853104,  0.94800174,  0.9995697 ,  0.6532897 ,  0.93846226,
         -0.6281378 ,  0.5574107 ,  0.725278  ,  0.74160355, -0.6486919 ,
          0.88869256,  0.9439776 , -0.9654787 , -0.95139974, -0.9366148 ,
          0.17409436,  0.83473635, -0.87414986, -0.35965624, -0.8395183 ,
          0.5546853 ,  0.7452196 , -0.6152899 , -0.82187194, -0.65487677,
          0.94367695,  0.6834396 , -0.72266734,  0.99376386, -0.76821744,
          0.4485644 ,  0.99982166,  1.        ,  0.9260674 ,  0.9759094 ,
          0.9397613 ,  0.8128903 , -0.7918152 ,  0.30299878, -0.95160294,
          0.25385544, -0.57780135, -0.9999994 ,  0.9168113 , -0.36585295,
          0.9798102 ,  0.95976156, -0.99428   ,  0.6471789 , -0.9948078 ,
         -0.9686591 ,  0.93615085, -0.11481134,  0.87566274, -0.91601896,
          0.9952683 ,  0.26532048,  0.99861896,  0.79298306,  0.5872364 ,
         -0.56314534,  0.96794534,  0.9999797 ,  0.9879324 ,  0.5003342 ,
          0.9516269 , -0.8878316 , -0.9665091 , -0.88037425,  0.8356687 ,
         -0.71543014, -0.99985015, -0.9414574 ,  0.8681497 ,  0.950698  ,
         -0.8007153 ,  0.78748596,  0.9999305 ,  0.40210736,  0.4856055 ,
         -0.9390776 ,  0.63564163, -0.85989815, -0.8421344 , -0.99436   ,
          0.78081733, -0.97038007,  0.39290914,  0.7834218 ,  0.88715357,
         -0.03653741,  0.99126273, -0.96559966,  0.11924513, -0.99363935,
         -0.9901692 ,  0.963858  ,  0.5713922 ,  0.5676979 ,  0.69982123,
          0.858003  ,  0.9983819 , -0.87965024,  0.46213093, -0.3256273 ,
          0.77337253,  0.7246244 , -0.99894017, -0.9170495 , -0.98803675,
         -0.93148243,  0.09674019,  0.09448949, -0.7453027 , -0.78955775,
         -0.6304773 , -0.5597632 ,  0.992308  ,  0.7769483 ,  0.04146893,
         -0.15876745, -0.7682887 , -0.5231416 ,  0.7871302 ,  0.9503481 ,
         -0.9607153 ,  0.99047405, -0.9948017 , -0.82257754,  0.9990552 ,
          0.79346406, -0.78624016,  0.8760266 , -0.7855991 ,  0.13444276,
         -0.7183107 , -0.9999819 ,  0.7019429 , -0.918913  , -0.6569654 ,
          0.9998794 , -0.33805153, -0.9427715 ,  0.10419375, -0.94257164,
          0.9187495 , -0.9994855 , -0.99979955, -0.9277688 ,  0.6353426 ,
          0.9994905 ,  0.90688777,  0.9992008 ,  0.7817533 , -0.9996674 ,
         -0.999962  , -0.13310781, -0.82505953,  0.9997485 ,  0.82616794,
         -0.999998  ,  0.45386457,  0.6069964 ,  0.52272975,  0.8811922 ,
          0.52668494, -0.9994814 , -0.21601789, -0.99882716,  0.90246916,
          0.94196504,  0.30058604, -0.9876776 , -0.7699927 , -0.9980288 ,
          0.7727592 ,  0.9936947 ,  0.98021245, -0.77723926, -0.785372  ,
          0.5150317 ,  0.9983137 , -0.7461883 ,  0.3311537 , -0.63709795,
         -0.6487831 , -0.9173727 ,  0.9997706 , -0.9999893 , -1.        ,
          0.60389155, -0.6516268 , -0.95422006,  1.        ,  0.09109057,
         -0.99999994,  0.99998957,  1.        , -0.19451752,  0.94624877,
         -0.2761865 ,  1.        ,  0.52399474,  0.70230734,  0.5218801 ,
         -0.99716544, -0.70075685, -0.99992603,  1.        , -0.9785006 ,
          0.22457084, -0.5356722 , -0.9991887 ,  0.7062409 ,  0.66816545,
         -0.90308225, -0.8084922 ,  0.50301254, -0.7062079 ,  0.9998321 ,
          0.9823206 ,  0.9984027 ,  0.9948857 , -1.        , -0.7067878 ,
          0.975454  ,  0.87161005, -0.9882297 ,  0.8296374 , -0.88615334,
          0.4316883 ,  0.86287475, -0.9893329 , -0.9022001 , -0.68322754,
         -0.84212875,  0.78632677, -0.5131366 , -0.996949  , -0.75479275,
         -0.06342169,  0.92238575,  0.66769385,  0.9926053 , -0.78391105,
          0.9976865 ,  0.07086544,  0.34079495,  0.69730175, -0.99970955,
         -1.        , -0.9860551 ,  0.89584446, -0.96889114, -0.90435815,
          0.944296  , -1.        , -0.9931756 , -0.7014334 , -0.6742562 ,
         -0.96786517,  0.848328  ,  0.8903087 , -0.9998633 ,  0.73993397,
          0.99345684,  0.9691821 ,  0.87563246, -0.6073146 , -0.9999999 ,
          0.90763575,  0.30225936, -0.47824544,  0.7179979 ,  0.9450465 ,
          0.9715953 , -0.5422173 ,  0.99995065, -0.5920663 ,  0.92390317,
         -0.9670669 , -0.3623574 ,  0.74825   , -0.7817521 ,  0.9888685 ,
         -0.7653631 , -0.8933355 ,  0.9481424 ,  0.97803396, -0.9999731 ,
         -0.89597356,  0.35502487, -0.7190486 ,  0.30777818,  0.55025375,
          0.6365793 , -0.99094397, -1.        ,  0.93482614, -0.99970514,
          0.98721176,  0.14699097, -0.86038756, -0.68365514, -0.8104672 ,
          0.57238674,  0.97475344, -0.9963499 ,  0.98476464,  0.40495875,
         -0.7001948 , -0.40898973,  0.61900675, -1.        , -0.9371812 ,
         -0.62749994, -0.8841316 , -0.9999847 , -0.39386114, -0.925245  ,
         -0.99991447, -0.5872595 ,  0.5835767 ,  0.7003338 , -0.9761974 ,
          0.99995846,  0.33676207,  0.9079994 , -0.76412004, -0.7648706 ,
          0.68863285,  0.43983305,  0.74911463, -0.99995685, -0.6692586 ,
         -0.45761266, -0.9980771 , -1.        ,  0.31244457, -0.8834693 ,
          0.9388263 , -0.987405  ,  1.        ,  0.9512058 ,  0.23448633,
          0.37940192,  0.99989796,  0.8402514 , -0.84526414,  0.7378776 ,
         -0.9996204 , -0.99434114,  0.9987527 ,  0.5569713 ,  0.99648696,
         -0.9933159 , -0.13116199,  0.9999992 ,  0.9642579 , -0.48285434,
         -0.97517425,  0.7185596 ,  0.5286405 ,  0.9902838 ,  0.7796022 ,
         -0.80703837,  0.2376029 ,  0.534117  , -0.9999413 ,  0.99828076,
          0.9998345 ,  0.93249476,  0.3620626 ,  0.7567034 , -0.9222681 ,
          0.97832036,  0.9999682 ,  0.6433209 , -1.        ,  0.9268615 ,
         -0.9999511 , -0.9145363 , -0.9213852 ,  0.7606066 , -0.5501025 ,
         -0.99999434, -0.7783993 ,  0.9999771 ,  0.99980384,  0.987094  ,
          0.7531475 , -0.8551696 , -0.9973968 , -0.9999853 , -0.08913276,
         -0.9919206 , -0.49190572,  0.70230234, -0.31277484, -0.99999964,
          0.828591  ,  0.6363776 ,  0.86796165,  0.81575817,  0.7782955 ,
          0.9436437 , -1.        , -0.7509046 , -0.9946139 , -0.6647415 ,
          0.999543  ,  0.9312092 , -1.        ,  0.5639159 ,  0.9482462 ,
         -0.9289936 , -0.9678435 ,  0.60937124, -0.987818  ,  0.5511619 ,
          0.75886583, -0.48466644, -0.71833754,  0.8042149 ,  0.9154103 ,
         -0.8177468 ,  0.7195895 , -0.82283056,  0.24990956, -1.        ,
          0.7729634 ,  0.84048635,  0.7989596 ,  0.9469012 , -0.9898951 ,
         -0.92565274,  0.74726975,  0.78213847, -0.672894  , -0.58831286,
         -0.8039038 , -0.72197783,  0.5289216 , -0.9998796 , -0.9904479 ,
          0.9996592 , -0.28984115,  0.23964961, -0.7427149 , -0.662416  ,
         -1.        , -0.5538268 , -0.9945287 , -0.63471127,  0.5896127 ,
         -0.48429146,  0.9976076 , -0.94329506, -0.49143887,  0.7695602 ,
          0.8638134 , -0.82130384,  0.50105464,  0.9336961 , -0.24716294,
         -0.6922282 , -0.02228704,  0.75649065,  0.82303154, -0.30867255,
         -0.9602714 ,  0.64568967,  0.314201  , -0.4811752 ,  0.27952817,
          0.9227022 ,  0.88095886,  0.89470226,  1.        , -0.19237158,
          1.        , -0.991253  , -0.9991121 ,  0.5637482 , -0.75780976,
         -0.3904836 , -0.9881965 , -0.2912058 ,  0.9998215 ,  0.9869475 ,
         -0.12784953,  0.81566185,  0.9787118 , -0.17835459, -0.7027824 ,
          0.72269535, -0.18194303,  0.9968796 ,  0.03490257,  0.7751488 ,
         -1.        , -0.7761089 ,  0.85105944,  0.9968074 , -0.8156342 ,
          0.5300792 , -1.        ,  0.99626255, -0.7515625 , -0.6672005 ,
          0.9792111 ,  0.8660997 , -0.69161206,  0.32184905,  0.9071073 ,
          0.9999385 , -0.82744277, -0.99044186, -0.71309817, -0.5004305 ,
          0.70707524,  0.89751345, -0.6819585 , -0.9999414 , -0.45255637,
         -0.94375473, -0.91838425,  0.64272994,  0.9375524 ,  0.6609169 ,
         -0.88743365, -0.9534722 , -0.47888806, -1.        , -0.5251781 ,
          0.8274516 ,  0.9326824 ,  0.8961964 ,  0.5295862 ,  0.43714878,
         -0.7488347 , -0.75295556, -0.5187054 ,  0.75924635, -0.7862662 ,
          0.99981725, -0.80290836,  0.97651815,  0.99763787, -0.29619345,
         -0.1252967 ,  0.33606276, -0.65137684, -0.9680231 ,  0.77586985,
          0.22347753,  0.27245504, -0.07826214, -0.8383849 , -0.85373163,
          1.        , -0.4563588 , -0.91339815, -0.9999861 ,  0.66063935,
         -0.985843  , -0.7818757 , -0.7000497 , -0.6840764 ,  0.9995542 ,
          0.60819125,  0.80064404, -0.9776968 , -0.90925264, -0.6644932 ,
         -0.8771755 ,  0.71411085,  0.8113569 ,  0.9974196 , -0.75211936,
          0.63400257, -0.8272833 ,  0.99780786,  0.9965285 ,  0.59551436,
         -0.9876875 , -0.04439292,  0.9939223 ,  0.9993717 , -0.9965501 ,
         -0.9630328 , -0.9027949 , -0.48490363, -0.60193753, -0.6870232 ,
         -0.95355797, -0.67561924,  0.9997761 , -0.85473967,  0.998495  ,
         -0.95756954,  0.633171  ,  0.4570475 , -0.5316367 , -0.9663824 ,
          0.9567106 , -0.45497724,  0.12964879,  0.9964744 , -0.9711668 ,
          0.69636106, -0.9178346 ,  0.8313186 ,  0.69686604,  0.8141587 ,
         -0.33600506,  0.94798595,  0.8800869 ,  0.15029034, -0.91185665,
          0.6322724 , -0.9971475 ,  0.71948224,  0.9695236 ,  0.84242374,
          0.99995124,  0.5982563 , -0.98341423,  0.61301434,  0.9997318 ,
         -0.9981808 , -0.65651804, -0.8484874 , -0.9961815 ,  0.9030814 ,
          0.87141925,  0.8897381 , -0.92870414,  0.07134341,  0.8739935 ,
          0.91630197, -0.9465984 , -0.59741104, -1.        ,  0.9989559 ,
          0.99991184,  0.67439264,  0.92025673, -0.60730827,  0.8362061 ,
          1.        , -0.70801497,  0.9883806 , -0.9984141 ,  0.9919259 ,
         -0.998869  ,  0.9976203 ,  0.9888036 ,  0.8556838 , -0.9722744 ,
         -0.99810714,  0.8182833 ,  0.98808485,  0.6643728 ,  0.99212515,
         -0.99988   ,  0.26405996,  0.93139845,  0.99021816,  0.6846886 ,
          0.9986462 ,  0.92254627, -0.6406982 ], dtype=float32)),
 ('The acting was a bit lacking',
  array([0.9921152 , 0.00788479], dtype=float32),
  array([-0.00791603, -4.842819  ], dtype=float32),
  'Negative',
  array([ 0.67417824,  0.8235167 ,  0.99999565, -0.8565971 , -0.99499583,
          0.8219966 , -0.9185583 , -0.5234593 ,  0.99962074,  0.99999714,
          0.9507927 , -0.9996754 ,  0.22211392, -0.99826247,  0.7562492 ,
          0.93803996,  0.82738185,  0.4773049 , -0.73478544,  0.85207295,
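Once getPrediction() has returned these tuples, the sentence embeddings can be stacked into a single matrix for downstream use. This is a minimal sketch: split_predictions is a hypothetical helper, it assumes the 5-tuple order used in getPrediction() above, and the sample data below is made up (with 4-d instead of 768-d embeddings).

```python
import numpy as np

def split_predictions(results):
    """Turn getPrediction() output into parallel lists/arrays.

    Each element is assumed to be the 5-tuple built by getPrediction():
    (sentence, probabilities, log_probabilities, label, pooled_output).
    """
    sentences = [r[0] for r in results]
    probabilities = np.stack([r[1] for r in results])  # [n, num_labels]
    labels = [r[3] for r in results]
    embeddings = np.stack([r[4] for r in results])     # [n, hidden_size]
    return sentences, probabilities, labels, embeddings

# Hypothetical stand-in for getPrediction(...) output
fake_results = [
    ("That movie was absolutely awful", np.array([0.996, 0.004]),
     np.log(np.array([0.996, 0.004])), "Negative",
     np.array([0.9, 0.7, -0.9, 0.1])),
    ("Great film", np.array([0.02, 0.98]),
     np.log(np.array([0.02, 0.98])), "Positive",
     np.array([-0.5, 0.2, 0.8, -0.3])),
]
sentences, probabilities, labels, embeddings = split_predictions(fake_results)
print(embeddings.shape)  # (2, 4)
```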


Source: https://stackoverflow.com/questions/60767089/bert-get-sentence-level-embedding-after-fine-tuning
