RNN 19 (Attention and Seq2Seq Learning Using the Addition Dataset in PyTorch)
Learning the Rules of Addition with Sequence-to-Sequence Learning
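The task treats addition as character-level sequence-to-sequence translation: the model reads a question string such as '289+436' one character at a time and must emit the answer, with '_' marking the start of the answer. Each line of addition.txt holds the question followed by the '_'-prefixed answer; a minimal illustration of how one line is split, mirroring the parsing code further down (the sample line and its trailing padding are assumptions based on the tensors inspected later):
line = "289+436_725 "        # hypothetical sample line from addition.txt
idx = line.find('_')         # '_' separates the question from the answer
print(line[:idx])            # 289+436
print(line[idx:].rstrip())   # _725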
import os
import random
import string
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import pickle
from copy import deepcopy
from sklearn.model_selection import train_test_split
# Fix the random seeds
seed = 50
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed) # fix the Python RNG seed
np.random.seed(seed) # fix the NumPy RNG seed
torch.manual_seed(seed) # fix the PyTorch RNG seed (CPU)
torch.cuda.manual_seed(seed) # fix the PyTorch RNG seed (single GPU)
torch.cuda.manual_seed_all(seed) # fix the PyTorch RNG seed (multi-GPU)
torch.backends.cudnn.deterministic = True # use deterministic cuDNN operations
torch.backends.cudnn.benchmark = False # disable the cuDNN benchmark mode
torch.backends.cudnn.enabled = False # disable cuDNN
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
1. Downloading the data
from google.colab import files
# upload addition.txt
t = files.upload()
Saving addition.txt to addition.txt
2. Reading the data
Vocabulary
class Vocabulary():
def __init__(self, file_path, vocab_from_file, vocab_file='./vocab.pkl'):
# initialize the dictionaries
self.char2idx = {}
self.idx2char = {}
self.idx = 0
questions, answers = [], []
for line in open(file_path, 'r'):
idx = line.find('_')
questions.append(line[:idx])
answers.append(line[idx:-1])
self.questions, self.answers = questions, answers
if vocab_from_file:
with open(vocab_file, 'rb') as f:
vocab = pickle.load(f)
self.char2idx = vocab.char2idx
self.idx2char = vocab.idx2char
print('Vocabulary successfully loaded from vocab.pkl file!')
else:
self.build_vocab()
with open(vocab_file, 'wb') as f:
pickle.dump(self, f)
def build_vocab(self):
for i in range(len(self.questions)):
question, answer = self.questions[i], self.answers[i]
self.add_char(question)
self.add_char(answer)
print('Vocabulary length : ', len(self.char2idx))
def add_char(self, txt):
chars = list(txt) # ['1', '6', '+', '7', '5', ' ', ' ']
for i, char in enumerate(chars):
if char not in self.char2idx:
tmp_id = len(self.char2idx)
self.char2idx[char] = tmp_id
self.idx2char[tmp_id] = char
def __len__(self):
return len(self.char2idx)
# Note. Vocabulary() is also created inside the Dataset class, but keeping a
# global instance here makes it easy to use independently of the Dataset class.
file_path='./addition.txt'
questions, answers = [], []
for line in open(file_path, 'r'):
idx = line.find('_')
questions.append(line[:idx])
answers.append(line[idx:-1])
vocab = Vocabulary(file_path, vocab_from_file=False)
Vocabulary length :  13
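The 13 characters are the ten digits, '+', the space used for padding, and the '_' marker that prefixes every answer. The mapping can be inspected directly (the exact indices depend on the order in which characters first appear in addition.txt, so the values below are only an example):
print(vocab.char2idx)   # e.g. {'1': 0, '6': 1, '+': 2, '7': 3, '5': 4, ' ': 5, '_': 6, ...}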
Dataset
class AdditionDataset(Dataset):
def __init__(self, file_path, questions, answers, vocab_from_file, vocab_file='./vocab.pkl'):
vocab = Vocabulary(file_path, vocab_from_file, vocab_file)
self.questions, self.answers = questions, answers
self.x = []
self.t = []
for i, question in enumerate(self.questions):
self.x.append([vocab.char2idx[c] for c in list(question)])
for i, answer in enumerate(self.answers):
self.t.append([vocab.char2idx[c] for c in list(answer)])
def __getitem__(self, index):
#return torch.LongTensor(self.x[index]), torch.LongTensor(self.t[index])
return torch.LongTensor(self.x[index][::-1]), torch.LongTensor(self.t[index])
def __len__(self):
return len(self.t)
# split into train/valid/test
train_indices, test_indices = train_test_split(range(len(questions)), test_size=0.1)
train_indices, valid_indices = train_test_split(train_indices, test_size=0.1)
questions, answers = np.array(questions), np.array(answers)
questions_train, questions_valid, questions_test = questions[train_indices], questions[valid_indices], questions[test_indices]
answers_train, answers_valid, answers_test = answers[train_indices], answers[valid_indices], answers[test_indices]
questions_train.shape, answers_train.shape
((40500,), (40500,))
trainset = AdditionDataset(file_path, questions_train, answers_train, vocab_from_file=False)
validset = AdditionDataset(file_path, questions_valid, answers_valid, vocab_from_file=True)
testset = AdditionDataset(file_path, questions_test, answers_test, vocab_from_file=True)
Vocabulary length :  13
Vocabulary successfully loaded from vocab.pkl file!
Vocabulary successfully loaded from vocab.pkl file!
trainset[0][0], trainset[0][1]
(tensor([5, 5, 0, 2, 1, 0, 1]), tensor([6, 1, 0, 3, 5]))
# input
for i in trainset[1][0]:
i = int(i)
char = vocab.idx2char[i]
print(char, end='')
print()
# target
for i in trainset[1][1]:
i = int(i)
char = vocab.idx2char[i]
print(char, end='')
634+982
_725
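Note that the printed input '634+982' is the reversed question: AdditionDataset.__getitem__ returns the question indices reversed with [::-1], a common seq2seq trick that shortens the distance between corresponding input and output digits. Undoing the reversal recovers the original question, 289+436 = 725, which matches the target '_725':
reversed_q = ''.join(vocab.idx2char[int(i)] for i in trainset[1][0])
print(reversed_q[::-1])   # 289+436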
3. Loading the data: DataLoader
batch_size = 128
trainloader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
validloader = DataLoader(dataset=validset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=batch_size, shuffle=True)
batch = next(iter(trainloader))
batch[0].size(), batch[1].size()
(torch.Size([128, 7]), torch.Size([128, 5]))
batch = next(iter(validloader))
batch[0].size(), batch[1].size()
(torch.Size([128, 7]), torch.Size([128, 5]))
batch = next(iter(testloader))
batch[0].size(), batch[1].size()
(torch.Size([128, 7]), torch.Size([128, 5]))
len(trainloader), len(validloader), len(testloader)
(317, 36, 40)
5. Building the model: Attention Seq2Seq
class AttentionEncoder(nn.Module):
def __init__(self, vocab_size, wordvec_size, hidden_size):
super().__init__()
self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=wordvec_size)
self.lstm = nn.LSTM(input_size=wordvec_size, hidden_size=hidden_size, batch_first=True)
def forward(self, inputs): # input shape (N=128, T=7)
embed = self.embed(inputs) # embed shape (N=128, T=7, D=16)
out, (h, c)= self.lstm(embed) # out shape (N=128, T=7, H=128)
# h(c) shape (num_layers=1, N=128, H=128)
return out
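Unlike a plain seq2seq encoder, which would hand over only the final hidden state, this encoder returns the LSTM outputs for every time step so that the decoder can attend over all input positions. A quick shape check on the CPU with random token ids (a sketch, not part of the original notebook):
enc = AttentionEncoder(vocab_size=13, wordvec_size=16, hidden_size=128)
dummy = torch.randint(0, 13, (2, 7))   # (N=2, T=7) random token ids
print(enc(dummy).shape)                # torch.Size([2, 7, 128])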
class Attention(nn.Module):
def __init__(self):
super().__init__()
def forward(self, enc_hs, dec_hs):
N, D_T, H = dec_hs.shape # N=128, D_T=4, H=128
_, E_T, _ = enc_hs.shape # N=128, E_T=7, H=128
out = torch.empty_like(dec_hs) # N=128, D_T=4, H=128
for t in range(D_T): # D_T : 4
# Attention weight
h = dec_hs[:, t, :] # h shape (N=128, H=128)
h = h.reshape(N, 1, H) # h shape (N=128, 1, H=128)
hr = h.repeat(1, E_T, 1) # hr shape (N=128, E_T=7, H=128)
t1 = enc_hs * hr # t1 shape (N=128, E_T=7, H=128)
s = torch.sum(t1, dim=2) # s shape (N=128, E_T=7)
a = torch.softmax(s, dim=1) # a shape (N=128, E_T=7)
# Weighted Sum
a = a.reshape(N, E_T, 1) # a shape (N=128, E_T=7, 1)
ar = a.repeat(1, 1, H) # ar shape (N=128, E_T=7, H=128)
t2 = enc_hs * ar # t2 shape (N=128, E_T=7, H=128)
c = torch.sum(t2, dim=1) # c shape (N=128, H=128)
out[:, t, :] = c
return out # out shape (N=128, D_T=4, H=128)
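For each decoder step t the module computes dot-product scores against every encoder state, turns them into weights with a softmax, and returns the weighted sum of encoder states as the context vector. The per-step loop can be written equivalently with batched matrix multiplications; a minimal sketch (not part of the original class, equal to the loop up to floating-point error):
def attention_bmm(enc_hs, dec_hs):
    scores = torch.bmm(dec_hs, enc_hs.transpose(1, 2))   # (N, D_T, E_T) dot-product scores
    weights = torch.softmax(scores, dim=2)               # attention weights over encoder steps
    return torch.bmm(weights, enc_hs)                    # (N, D_T, H) context vectors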
class AttentionDecoder(nn.Module):
def __init__(self, vocab_size, wordvec_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=wordvec_size)
self.lstm = nn.LSTM(input_size=wordvec_size, hidden_size=hidden_size, batch_first=True)
self.attention = Attention()
self.affine = nn.Linear(in_features=hidden_size+hidden_size, out_features=vocab_size)
def forward(self, inputs, enc_hs): # inputs shape (N=128, T=4)
# enc_hs shape (N=128, T=7, H=128)
N, T = inputs.shape
N, T, H = enc_hs.shape
h = enc_hs[:,-1].unsqueeze(0) # h shape (num_layers=1, N=128, H=128)
c = self.init_cell(N)
embed = self.embed(inputs) # embed shape (N=128, T=4, D=16)
dec_hs, _ = self.lstm(embed, (h, c)) # dec_hs shape (N=128, T=4, H=128)
context = self.attention(enc_hs, dec_hs) # enc_hs shape (N=128, T=7, H=128)
# dec_hs shape (N=128, T=4, H=128)
# context shape (N=128, D_T=4, H=128)
out = torch.cat((context, dec_hs), dim=2) # out shape (N=128, T=4, H+H=128+128)
out = self.affine(out) # out shape (N=128, T=4, V=13)
#out = out.view(-1, self.vocab_size) # out shape (NxT=128*4, V=13)
return out
def init_cell(self, batch_size):
weight = next(self.parameters())
return weight.new_zeros(1, batch_size, self.hidden_size)
def generate(self, enc_hs, start_id, sample_size):
sampled = []
sample_id = start_id
h = enc_hs[:, -1].unsqueeze(0) # h shape (num_layers=1, N=1, H=128)
c = self.init_cell(batch_size = 1) # c shape : (num_layers=1, N=1, H=128)
for _ in range(sample_size):
# sample_id = torch.tensor(sample_id).reshape(1, 1) # sample_id shape : (N=1, T=1)
sample_id = sample_id.clone().detach().reshape(1, 1) # remove userwarning
embed = self.embed(sample_id) # embed shape : (N=1, T=1, D=16)
dec_hs, (h, c)= self.lstm(embed, (h, c)) # out shape : (N=1, T=1, H=128)
context = self.attention(enc_hs, dec_hs)
out = torch.cat((context, dec_hs), dim=2) # out shape : (N=1, T=1, H+H=128+128 )
score = self.affine(out) # score shape : (N=1, T=1, V=13)
sample_id = torch.max(score, dim=2)[1]
sampled.append(int(sample_id))
return sampled
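generate() decodes greedily: at every step torch.max(score, dim=2)[1] picks the index of the highest-scoring character, which is then fed back as the next decoder input. The argmax pattern in isolation (toy scores, not produced by the model):
score = torch.tensor([[[0.1, 2.0, 0.3]]])   # toy scores with shape (N=1, T=1, V=3)
next_id = torch.max(score, dim=2)[1]        # index of the highest-scoring character
print(next_id)                              # tensor([[1]])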
class AttentionSeq2Seq(nn.Module):
def __init__(self, vocab_size, wordvec_size, hidden_size):
super().__init__()
self.encoder = AttentionEncoder(vocab_size, wordvec_size, hidden_size)
self.decoder = AttentionDecoder(vocab_size, wordvec_size, hidden_size)
def forward(self, inputs, targets): # inputs shape (N=128, T=7)
# targets shape (N=128, T=5)
decoder_in = targets[:, :-1] # decoder_in shape (N=128, T=4)
h = self.encoder(inputs) # h shape (N=128, T=7, H=128) : all encoder hidden states
out = self.decoder(decoder_in, h) # out shape (N=128, T=4, V=13)
return out
def generate(self, inputs, start_id, sample_size): # inputs : (N=1, T=7)
h = self.encoder(inputs) # h shape (N=1, T=7, H=128) : all encoder hidden states
sampled = self.decoder.generate(h, start_id, sample_size) # start_id = 6('_'), sample_size=4
return sampled
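During training the decoder is teacher-forced: forward() feeds it targets[:, :-1] as input, and the training loop later compares the logits against targets[:, 1:]. On the character level the shift looks like this (the trailing space is the padding assumed for answers shorter than four digits):
target = list("_725 ")       # one 5-character answer; '_' is the start marker
decoder_in = target[:-1]     # ['_', '7', '2', '5']  -> decoder input (teacher forcing)
decoder_out = target[1:]     # ['7', '2', '5', ' ']  -> what the loss compares against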
Hyperparameters
vocab_size = len(vocab)
wordvec_size = 16
hidden_size = 128
batch_size = 128
learning_rate = 0.01
num_epochs = 100
model = AttentionSeq2Seq(vocab_size=vocab_size,
wordvec_size=wordvec_size,
hidden_size=hidden_size)
model = model.to(device)
model
AttentionSeq2Seq(
(encoder): AttentionEncoder(
(embed): Embedding(13, 16)
(lstm): LSTM(16, 128, batch_first=True)
)
(decoder): AttentionDecoder(
(embed): Embedding(13, 16)
(lstm): LSTM(16, 128, batch_first=True)
(attention): Attention()
(affine): Linear(in_features=256, out_features=13, bias=True)
)
)
out = model(batch[0].to(device), batch[1].to(device))
out.shape
torch.Size([128, 4, 13])
6. Configuring the model (loss function and optimizer)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
mode='min', factor=0.4,
patience=3, verbose=True)
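With factor=0.4, every time the validation loss plateaus for patience=3 epochs the learning rate is multiplied by 0.4, giving 0.01 -> 0.004 -> 0.0016 -> 0.00064, exactly the values reported in the training log below:
print([round(0.01 * 0.4 ** k, 6) for k in range(4)])   # [0.01, 0.004, 0.0016, 0.00064]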
7. Training the model
def validate(model, validloader, loss_fn):
model.eval()
total = 0
correct = 0
valid_loss = []
valid_epoch_loss=0
with torch.no_grad():
for batch_data in validloader:
inputs = batch_data[0].to(device)
targets = batch_data[1].to(device)
# forward pass and loss (no optimizer update during validation)
logits = model(inputs, targets)
targets = targets[:, 1:].clone() # shift by one position to get the decoder targets
loss = loss_fn(logits.view(-1, vocab_size), targets.view(-1))
valid_loss.append(loss.item())
valid_epoch_loss = np.mean(valid_loss)
total_loss["val"].append(valid_epoch_loss)
return valid_epoch_loss
def eval_seq2seq(model, question, correct, idx2char, verbose=False, is_reverse=True):
model.eval()
correct = correct.flatten()
# start character (the '_' marker)
start_id = correct[0]
correct = correct[1:]
guess = model.generate(question, start_id, len(correct))
# convert index sequences to strings
question = ''.join([idx2char[int(c)] for c in question.flatten()])
correct = ''.join([idx2char[int(c)] for c in correct])
guess = ''.join([idx2char[int(c)] for c in guess])
if verbose :
if is_reverse:
question = question[::-1]
print('Question : ', question)
print('True : ', correct)
print('Guess : ', guess)
print()
return 1 if guess == correct else 0
def train_loop(model, trainloader, loss_fn, epochs, optimizer):
min_loss = 1000000
trigger = 0
patience = 5
max_grad = 5.0
for epoch in range(epochs):
model.train()
train_loss = []
for batch_data in (trainloader):
inputs = batch_data[0].to(device)
targets = batch_data[1].to(device)
optimizer.zero_grad()
logits = model(inputs, targets)
targets = targets[:, 1:].clone() # shift by one position to get the decoder targets
loss = loss_fn(logits.view(-1, vocab_size), targets.view(-1))
loss.backward()
# clipping gradient
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad)
optimizer.step()
train_loss.append(loss.item())
train_epoch_loss = np.mean(train_loss)
total_loss["train"].append(train_epoch_loss)
valid_epoch_loss = validate(model, validloader, loss_fn)
# for valid accuracy (it takes time!!)
correct_num = 0
for i in range(len(validset)):
question = validset[i][0].unsqueeze(0).to(device)
correct = validset[i][1].unsqueeze(0).to(device)
correct_num += eval_seq2seq(model, question, correct, vocab.idx2char, verbose=False, is_reverse=True)
valid_accuracy = correct_num /len(validset)
print("Epoch: {}/{}, Train Loss={:.4f}, Val Loss={:.4f}, Val Accuracy={:.2f}".format(
epoch + 1, epochs,
total_loss["train"][-1],
total_loss["val"][-1],
valid_accuracy
))
# early stopping
if valid_epoch_loss > min_loss: # the validation loss failed to improve on the best value so far
trigger += 1
print('trigger : ', trigger)
if trigger > patience:
print('Early Stopping !!!')
print('Training loop is finished !!')
return
else:
trigger = 0
min_loss = valid_epoch_loss
best_model_state = deepcopy(model.state_dict())
torch.save(best_model_state, 'best_checkpoint.pth')
# -------------------------------------------
# Learning Rate Scheduler
scheduler.step(valid_epoch_loss)
# -------------------------------------------
total_loss = {"train": [], "val": []}
%time train_loop(model, trainloader, loss_fn, num_epochs, optimizer)
Epoch: 1/100, Train Loss=1.3811, Val Loss=0.9915, Val Accuracy=0.05
Epoch: 2/100, Train Loss=0.8615, Val Loss=0.7612, Val Accuracy=0.10
Epoch: 3/100, Train Loss=0.6989, Val Loss=0.6416, Val Accuracy=0.15
Epoch: 4/100, Train Loss=0.6054, Val Loss=0.5912, Val Accuracy=0.17
Epoch: 5/100, Train Loss=0.5452, Val Loss=0.5182, Val Accuracy=0.22
Epoch: 6/100, Train Loss=0.5126, Val Loss=0.4928, Val Accuracy=0.23
Epoch: 7/100, Train Loss=0.4904, Val Loss=0.5027, Val Accuracy=0.22
trigger : 1
Epoch: 8/100, Train Loss=0.4591, Val Loss=0.4741, Val Accuracy=0.24
Epoch: 9/100, Train Loss=0.4374, Val Loss=0.4420, Val Accuracy=0.30
Epoch: 10/100, Train Loss=0.4195, Val Loss=0.4163, Val Accuracy=0.31
Epoch: 11/100, Train Loss=0.3960, Val Loss=0.3808, Val Accuracy=0.37
Epoch: 12/100, Train Loss=0.3339, Val Loss=0.2934, Val Accuracy=0.55
Epoch: 13/100, Train Loss=0.2567, Val Loss=0.2198, Val Accuracy=0.69
Epoch: 14/100, Train Loss=0.1784, Val Loss=0.1585, Val Accuracy=0.79
Epoch: 15/100, Train Loss=0.1204, Val Loss=0.1008, Val Accuracy=0.91
Epoch: 16/100, Train Loss=0.0685, Val Loss=0.0630, Val Accuracy=0.94
Epoch: 17/100, Train Loss=0.0494, Val Loss=0.0424, Val Accuracy=0.96
Epoch: 18/100, Train Loss=0.0332, Val Loss=0.0311, Val Accuracy=0.97
Epoch: 19/100, Train Loss=0.0274, Val Loss=0.0601, Val Accuracy=0.94
trigger : 1
Epoch: 20/100, Train Loss=0.0344, Val Loss=0.0288, Val Accuracy=0.97
Epoch: 21/100, Train Loss=0.0430, Val Loss=0.0330, Val Accuracy=0.97
trigger : 1
Epoch: 22/100, Train Loss=0.0225, Val Loss=0.0339, Val Accuracy=0.96
trigger : 2
Epoch: 23/100, Train Loss=0.0175, Val Loss=0.0263, Val Accuracy=0.98
Epoch: 24/100, Train Loss=0.0237, Val Loss=0.0205, Val Accuracy=0.98
Epoch: 25/100, Train Loss=0.0235, Val Loss=0.0445, Val Accuracy=0.95
trigger : 1
Epoch: 26/100, Train Loss=0.0269, Val Loss=0.0159, Val Accuracy=0.99
Epoch: 27/100, Train Loss=0.0185, Val Loss=0.0300, Val Accuracy=0.97
trigger : 1
Epoch: 28/100, Train Loss=0.0264, Val Loss=0.0228, Val Accuracy=0.97
trigger : 2
Epoch: 29/100, Train Loss=0.0150, Val Loss=0.0084, Val Accuracy=0.99
Epoch: 30/100, Train Loss=0.0237, Val Loss=0.0260, Val Accuracy=0.97
trigger : 1
Epoch: 31/100, Train Loss=0.0219, Val Loss=0.0295, Val Accuracy=0.97
trigger : 2
Epoch: 32/100, Train Loss=0.0193, Val Loss=0.0126, Val Accuracy=0.98
trigger : 3
Epoch: 33/100, Train Loss=0.0129, Val Loss=0.0402, Val Accuracy=0.96
trigger : 4
Epoch 00033: reducing learning rate of group 0 to 4.0000e-03.
Epoch: 34/100, Train Loss=0.0055, Val Loss=0.0036, Val Accuracy=1.00
Epoch: 35/100, Train Loss=0.0010, Val Loss=0.0023, Val Accuracy=1.00
Epoch: 36/100, Train Loss=0.0007, Val Loss=0.0020, Val Accuracy=1.00
Epoch: 37/100, Train Loss=0.0006, Val Loss=0.0018, Val Accuracy=1.00
Epoch: 38/100, Train Loss=0.0005, Val Loss=0.0017, Val Accuracy=1.00
Epoch: 39/100, Train Loss=0.0004, Val Loss=0.0016, Val Accuracy=1.00
Epoch: 40/100, Train Loss=0.0004, Val Loss=0.0015, Val Accuracy=1.00
Epoch: 41/100, Train Loss=0.0003, Val Loss=0.0014, Val Accuracy=1.00
Epoch: 42/100, Train Loss=0.0003, Val Loss=0.0013, Val Accuracy=1.00
Epoch: 43/100, Train Loss=0.0002, Val Loss=0.0012, Val Accuracy=1.00
Epoch: 44/100, Train Loss=0.0002, Val Loss=0.0011, Val Accuracy=1.00
Epoch: 45/100, Train Loss=0.0002, Val Loss=0.0011, Val Accuracy=1.00
Epoch: 46/100, Train Loss=0.0002, Val Loss=0.0010, Val Accuracy=1.00
Epoch: 47/100, Train Loss=0.0001, Val Loss=0.0010, Val Accuracy=1.00
Epoch: 48/100, Train Loss=0.0001, Val Loss=0.0009, Val Accuracy=1.00
Epoch: 49/100, Train Loss=0.0001, Val Loss=0.0008, Val Accuracy=1.00
Epoch: 50/100, Train Loss=0.0001, Val Loss=0.0008, Val Accuracy=1.00
Epoch: 51/100, Train Loss=0.0001, Val Loss=0.0008, Val Accuracy=1.00
trigger : 1
Epoch: 52/100, Train Loss=0.0001, Val Loss=0.0008, Val Accuracy=1.00
Epoch: 53/100, Train Loss=0.0001, Val Loss=0.0008, Val Accuracy=1.00
Epoch: 54/100, Train Loss=0.0001, Val Loss=0.0007, Val Accuracy=1.00
Epoch: 55/100, Train Loss=0.0000, Val Loss=0.0007, Val Accuracy=1.00
Epoch: 56/100, Train Loss=0.0000, Val Loss=0.0007, Val Accuracy=1.00
trigger : 1
Epoch: 57/100, Train Loss=0.0000, Val Loss=0.0007, Val Accuracy=1.00
trigger : 2
Epoch: 58/100, Train Loss=0.0000, Val Loss=0.0008, Val Accuracy=1.00
trigger : 3
Epoch: 59/100, Train Loss=0.0000, Val Loss=0.0008, Val Accuracy=1.00
trigger : 4
Epoch 00059: reducing learning rate of group 0 to 1.6000e-03.
Epoch: 60/100, Train Loss=0.0000, Val Loss=0.0008, Val Accuracy=1.00
trigger : 5
Epoch: 61/100, Train Loss=0.0000, Val Loss=0.0008, Val Accuracy=1.00
trigger : 6
Epoch: 62/100, Train Loss=0.0000, Val Loss=0.0008, Val Accuracy=1.00
trigger : 7
Epoch: 63/100, Train Loss=0.0000, Val Loss=0.0008, Val Accuracy=1.00
trigger : 8
Epoch 00063: reducing learning rate of group 0 to 6.4000e-04.
Epoch: 64/100, Train Loss=0.0000, Val Loss=0.0008, Val Accuracy=1.00
trigger : 9
Epoch: 65/100, Train Loss=0.0000, Val Loss=0.0008, Val Accuracy=1.00
trigger : 10
Epoch: 66/100, Train Loss=0.0000, Val Loss=0.0008, Val Accuracy=1.00
trigger : 11
Early Stopping !!!
Training loop is finished !!
CPU times: user 22min 19s, sys: 3.35 s, total: 22min 22s
Wall time: 22min 31s
import matplotlib.pyplot as plt
plt.plot(total_loss['train'], label="train_loss")
plt.plot(total_loss['val'], label="valid_loss")
plt.legend()
plt.show()
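train_loop saves the weights with the lowest validation loss to best_checkpoint.pth, while the cells below evaluate the model in whatever state it ended training. To score the best checkpoint instead, a minimal sketch (assuming the file was written during training):
model.load_state_dict(torch.load('best_checkpoint.pth', map_location=device))
model.eval()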
8. Evaluating the model
# evaluate on the entire test set
correct_num = 0
for i in range(len(testset)):
question = testset[i][0].unsqueeze(0).to(device)
correct = testset[i][1].unsqueeze(0).to(device)
correct_num += eval_seq2seq(model, question, correct, vocab.idx2char, verbose=False, is_reverse=True)
test_accuracy = correct_num / len(testset)
test_accuracy # without attention the accuracy was about 0.9942
0.9998
9. Making predictions
for i in range(0, 3):
question = testset[i][0].unsqueeze(0).to(device)
correct = testset[i][1].unsqueeze(0).to(device)
eval_seq2seq(model, question, correct, vocab.idx2char, verbose=True, is_reverse=True)
Question : 31+648
True : 679
Guess : 679
Question : 744+531
True : 1275
Guess : 1275
Question : 7+909
True : 916
Guess : 916
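To query the trained model with an arbitrary problem, the question must be padded to 7 characters, encoded with the vocabulary, and reversed exactly as in AdditionDataset. The helper below is a sketch and not part of the original notebook (predict and its arguments are hypothetical names):
def predict(question_str, answer_len=4):
    # hypothetical helper: encode a question as AdditionDataset does, then decode greedily
    model.eval()
    padded = question_str.ljust(7)                      # pad to the fixed question width of 7
    ids = [vocab.char2idx[c] for c in padded][::-1]     # encode, then reverse as in __getitem__
    x = torch.LongTensor(ids).unsqueeze(0).to(device)   # shape (N=1, T=7)
    start_id = torch.tensor(vocab.char2idx['_']).to(device)
    sampled = model.generate(x, start_id, answer_len)
    return ''.join(vocab.idx2char[i] for i in sampled).strip()

print(predict('57+68'))   # expected: 125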