Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - An error occurs when predicting with the same data used for training (expects 3 input(s), but it received 75 input tensors.)

After training the model, I tried to make predictions, but an error occurred and I don't know how to fix it.

The model is built on ELECTRA (TFElectraModel from the Hugging Face transformers library).

Here is my model:

import tensorflow as tf
from transformers import TFElectraModel

MAX_LEN = 25  # matches the (None, 25) input shapes in the summary

electra = TFElectraModel.from_pretrained("monologg/koelectra-base-v3-discriminator", from_pt=True)
input_ids = tf.keras.Input(shape=(MAX_LEN,), name='input_ids', dtype=tf.int32)
mask = tf.keras.Input(shape=(MAX_LEN,), name='attention_mask', dtype=tf.int32)
token = tf.keras.Input(shape=(MAX_LEN,), name='token_type_ids', dtype=tf.int32)
# [0] selects the last hidden state: (batch, MAX_LEN, 768)
embeddings = electra(input_ids, attention_mask=mask, token_type_ids=token)[0]
X = tf.keras.layers.GlobalMaxPool1D()(embeddings)
X = tf.keras.layers.BatchNormalization()(X)
X = tf.keras.layers.Dense(128, activation='relu')(X)
X = tf.keras.layers.Dropout(0.1)(X)
y = tf.keras.layers.Dense(3, activation='softmax', name='outputs')(X)
model = tf.keras.Model(inputs=[input_ids, mask, token], outputs=y)
# Per the summary, index 2 is the token_type_ids InputLayer, not ELECTRA,
# so this does not freeze ELECTRA (electra.trainable = False before compiling would)
model.layers[2].trainable = False
model.summary()

And here is the summary:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_ids (InputLayer)          [(None, 25)]         0                                            
__________________________________________________________________________________________________
attention_mask (InputLayer)     [(None, 25)]         0                                            
__________________________________________________________________________________________________
token_type_ids (InputLayer)     [(None, 25)]         0                                            
__________________________________________________________________________________________________
tf_electra_model_4 (TFElectraMo TFBaseModelOutput(la 112330752   input_ids[0][0]                  
                                                                 attention_mask[0][0]             
                                                                 token_type_ids[0][0]             
__________________________________________________________________________________________________
global_max_pooling1d_6 (GlobalM (None, 768)          0           tf_electra_model_4[3][0]         
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 768)          3072        global_max_pooling1d_6[0][0]     
__________________________________________________________________________________________________
dense_18 (Dense)                (None, 128)          98432       batch_normalization_7[0][0]      
__________________________________________________________________________________________________
dropout_390 (Dropout)           (None, 128)          0           dense_18[0][0]                   
__________________________________________________________________________________________________
outputs (Dense)                 (None, 3)            387         dropout_390[0][0]                
==================================================================================================
Total params: 112,432,643
Trainable params: 112,431,107
Non-trainable params: 1,536
__________________________________________________________________________________________________
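As a sanity check, the numbers in this summary can be re-derived by hand. The sketch below assumes the standard Keras parameter formulas (Dense: inputs × units + biases; BatchNormalization: four vectors of size 768, of which the moving mean and moving variance are non-trainable) and takes the ELECTRA count straight from the summary. It shows that the 1,536 non-trainable parameters are exactly BatchNorm's moving statistics, so all 112,330,752 ELECTRA weights are still counted as trainable and freezing model.layers[2] had no effect on them.

```python
# Re-derive the parameter counts printed by model.summary().
HIDDEN = 768       # ELECTRA-base hidden size (GlobalMaxPool1D output is (None, 768))
DENSE_UNITS = 128
NUM_CLASSES = 3

electra_params = 112_330_752                       # as reported for tf_electra_model_4
batchnorm_params = 4 * HIDDEN                      # gamma, beta, moving_mean, moving_variance
dense_params = HIDDEN * DENSE_UNITS + DENSE_UNITS  # weights + biases
output_params = DENSE_UNITS * NUM_CLASSES + NUM_CLASSES

total = electra_params + batchnorm_params + dense_params + output_params
# Only BatchNorm's moving mean and variance are non-trainable.
non_trainable = 2 * HIDDEN
trainable = total - non_trainable

print(batchnorm_params, dense_params, output_params)  # 3072 98432 387
print(total, trainable, non_trainable)                # 112432643 112431107 1536
```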

This is the code that builds the training dataset:

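import numpy as np
from tqdm import tqdm

# train_data: DataFrame with "content" and "label" columns (defined elsewhere).
# Electra_tokenizer: helper defined elsewhere (not shown) that returns
# (input_ids, attention_mask, token_type_ids) of length MAX_LEN for one sentence.
input_ids = []
attention_masks = []
token_type_ids = []
train_data_labels = []

for train_sent, train_label in tqdm(zip(train_data["content"], train_data["label"]), total=len(train_data)):
    try:
        input_id, attention_mask, token_type_id = Electra_tokenizer(train_sent, MAX_LEN)
        input_ids.append(input_id)
        attention_masks.append(attention_mask)
        token_type_ids.append(token_type_id)
        train_data_labels.append(train_label)
    except Exception as e:
        # Report and skip sentences the tokenizer fails on
        print(e)
        print(train_sent)

# Stack the per-sentence sequences into (num_samples, MAX_LEN) arrays, one per model input
train_input_ids = np.array(input_ids, dtype=int)
train_attention_masks = np.array(attention_masks, dtype=int)
train_type_ids = np.array(token_type_ids, dtype=int)
intent_train_inputs = (train_input_ids, train_attention_masks, train_type_ids)
intent_train_data_labels = np.asarray(train_data_labels, dtype=np.int32)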

This is the shape of the training data:

tf.Tensor([ 3 75 25], shape=(3,), dtype=int32)
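The shape [3, 75, 25] reads as 3 inputs × 75 samples × 25 tokens (MAX_LEN). The pure-Python sketch below uses dummy zero-filled token lists in place of real tokenizer output (it assumes Electra_tokenizer returns plain sequences of length MAX_LEN) and compares that layout with a single tokenized sentence packed into a tuple the way the prediction code does it: the middle (samples) axis disappears.

```python
MAX_LEN = 25  # matches the (None, 25) input shapes in the summary

def nested_shape(x):
    """Shape of a rectangular nested list/tuple, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape = []
    while isinstance(x, (list, tuple)):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(shape)

# Hypothetical stand-ins: 75 tokenized sentences, each of length MAX_LEN.
one_sentence = [0] * MAX_LEN
train_inputs = ([one_sentence] * 75, [one_sentence] * 75, [one_sentence] * 75)
# A single tokenized sentence, packed as in the prediction code.
sample_inputs = (one_sentence, one_sentence, one_sentence)

print(nested_shape(train_inputs))   # (3, 75, 25): inputs x samples x tokens
print(nested_shape(sample_inputs))  # (3, 25): the samples axis is missing
```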

Training on this data works fine, but when I run the following code to make a prediction, an error occurs:

sample_text = 'this is sample text'
input_id, attention_mask, token_type_id = Electra_tokenizer(sample_text, MAX_LEN)
sample_text = (input_id, attention_mask, token_type_id)
model(sample_text)  # or model.predict(sample_text)

Here is the error:

Layer model_15 expects 3 input(s), but it received 75 input tensors. Inputs received: [<tf.Tensor: shape=(), dtype=int32, numpy=2>, <tf.Tensor: ....

I'm passing the same shape as during training, so why do I get this error, and how can I fix it?

I hope you have a great year ahead. Happy New Year!

Question from: https://stackoverflow.com/questions/65517232/an-error-occurs-when-predict-with-the-same-data-as-when-performing-train-expect

He who fights dragons for too long becomes a dragon himself; gaze into the abyss for too long, and the abyss will gaze back…

1 Reply


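It was a tensor dimension problem. The model expects each of its three inputs to have a batch dimension, i.e. shape (1, MAX_LEN) for a single sentence, but the tokenizer output for one sentence is a single sequence of MAX_LEN ids. Convert each input to a NumPy array, add the batch axis with np.expand_dims, and pass the three arrays to the model as one list: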
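The number 75 in the error message is 3 × 25. Assuming Electra_tokenizer returns plain Python lists (its code isn't shown), the tuple (input_id, attention_mask, token_type_id) is a nested structure, and Keras flattens nested inputs before counting them, so every token id becomes its own scalar tensor (hence shape=() in the message). The recursive flatten below is a simplified pure-Python stand-in for that step, not Keras's actual implementation; the token values are dummies. NumPy arrays and tensors, by contrast, are treated as single leaves, which is why three batched arrays are counted as three inputs.

```python
MAX_LEN = 25

def flatten(structure):
    """Recursively flatten lists/tuples into a flat list of leaves
    (a simplified stand-in for how Keras flattens nested model inputs)."""
    if isinstance(structure, (list, tuple)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]

# Hypothetical tokenizer output for one sentence, as plain lists of length MAX_LEN.
input_id = [2] + [0] * (MAX_LEN - 1)
attention_mask = [1] + [0] * (MAX_LEN - 1)
token_type_id = [0] * MAX_LEN

leaves = flatten((input_id, attention_mask, token_type_id))
print(len(leaves))  # 75 == 3 * MAX_LEN: one "input" per token id
```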

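import numpy as np

# Tokenize one sentence with the same helper used for training
test_input_ids, test_attention_mask, test_token_type_id = Electra_tokenizer('this is sample text', MAX_LEN)
test_input_ids = np.array(test_input_ids, dtype=np.int32)
test_attention_mask = np.array(test_attention_mask, dtype=np.int32)
test_token_type_id = np.array(test_token_type_id, dtype=np.int32)
# Add the batch axis: (MAX_LEN,) -> (1, MAX_LEN)
ids = np.expand_dims(test_input_ids, axis=0)
atm = np.expand_dims(test_attention_mask, axis=0)
tok = np.expand_dims(test_token_type_id, axis=0)
# Pass all three inputs as a single list, in the order of model.inputs; this works fine
probs = model([ids, atm, tok])  # or model.predict([ids, atm, tok]); shape (1, 3)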
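The same rule applies when predicting several sentences at once: each of the three inputs needs a leading batch axis. The pure-Python sketch below is illustrative only (the helper name to_model_inputs and the dummy data are hypothetical, and Electra_tokenizer is assumed to return three sequences of length MAX_LEN). It regroups per-sentence triples into three per-input batches, the same regrouping the training loop in the question performs.

```python
MAX_LEN = 25

def to_model_inputs(encoded):
    """Turn a list of (input_id, attention_mask, token_type_id) triples into
    three batches, one per model input, each of shape (len(encoded), MAX_LEN)."""
    ids, masks, types = [], [], []
    for input_id, attention_mask, token_type_id in encoded:
        for seq in (input_id, attention_mask, token_type_id):
            if len(seq) != MAX_LEN:
                raise ValueError(f"expected {MAX_LEN} tokens, got {len(seq)}")
        ids.append(list(input_id))
        masks.append(list(attention_mask))
        types.append(list(token_type_id))
    return [ids, masks, types]

# Hypothetical tokenizer output for two sentences (dummy ids of length MAX_LEN).
pad = [0] * MAX_LEN
encoded = [(pad, [1] * MAX_LEN, pad), (pad, [1] * MAX_LEN, pad)]
batch = to_model_inputs(encoded)
print(len(batch), len(batch[0]), len(batch[0][0]))  # 3 2 25
```

Each of the three lists would then be converted with np.array(..., dtype=np.int32) and passed as model.predict([...]), just like the training arrays; a batch containing one sentence yields the same (1, MAX_LEN) shape that np.expand_dims produces.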


...