RNN 9 (predict news category using RNN in PyTorch)
Predicting AG News categories with an RNN model
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import pickle
import nltk
nltk.download('punkt')
import string
from collections import Counter
from copy import deepcopy
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data] Package punkt is already up-to-date!
# Fix random seeds for reproducibility
seed = 50
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)                # Python RNG
np.random.seed(seed)             # NumPy RNG
torch.manual_seed(seed)          # PyTorch RNG (CPU)
torch.cuda.manual_seed(seed)     # PyTorch RNG (single GPU)
torch.cuda.manual_seed_all(seed) # PyTorch RNG (multi-GPU)
torch.backends.cudnn.deterministic = True  # use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False     # disable cuDNN benchmarking
torch.backends.cudnn.enabled = False       # disable cuDNN
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cpu')
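A quick sanity check of the seeding above (a minimal sketch): re-seeding the PyTorch generator should reproduce the exact same random draws.
# Illustrative reproducibility check: the same seed yields the same tensor.
torch.manual_seed(seed)
a = torch.rand(3)
torch.manual_seed(seed)
b = torch.rand(3)
print(torch.equal(a, b))  # True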
1. Downloading the Data
# Install the package that provides the Kaggle API
!pip install kaggle
# kaggle.json upload
from google.colab import files
files.upload()
# Avoid permission warnings
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
# download
!kaggle datasets download -d amananandrai/ag-news-classification-dataset
!mkdir dataset
# unzip the archive
!unzip -q ag-news-classification-dataset.zip -d dataset/
Saving kaggle.json to kaggle.json
Downloading ag-news-classification-dataset.zip to /content
79% 9.00M/11.4M [00:00<00:00, 83.3MB/s]
100% 11.4M/11.4M [00:00<00:00, 97.5MB/s]
2. Loading the Data
Vocabulary
class Vocabulary():
def __init__(self, vocab_threshold, vocab_file,
mask_word="<mask>",
start_word="<start>",
end_word="<end>",
unk_word="<unk>",
news_df=None, vocab_from_file=False):
        self.vocab_threshold = vocab_threshold
        # the full dataset (train_news_df) before the train/valid split
        self.news_df = news_df
        # store the special tokens so build_vocab() can use them
        self.mask_word = mask_word
        self.start_word = start_word
        self.end_word = end_word
        self.unk_word = unk_word
        # initialize the lookup dictionaries
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
        if vocab_from_file:
            # load an existing vocabulary from file
            with open(vocab_file, 'rb') as f:
                vocab = pickle.load(f)
                self.word2idx = vocab.word2idx
                self.idx2word = vocab.idx2word
            print('Vocabulary successfully loaded from vocab.pkl file!')
else:
self.build_vocab()
with open(vocab_file, 'wb') as f:
pickle.dump(self, f)
    def build_vocab(self):
        # mask_word (0), start_word (1), end_word (2), unk_word (3)
        self.mask_index = self.add_word(self.mask_word)       # 0
        self.begin_seq_index = self.add_word(self.start_word) # 1
        self.end_seq_index = self.add_word(self.end_word)     # 2
        self.unk_index = self.add_word(self.unk_word)         # 3
        self.add_description()
    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1
        # return the index whether the word was new or already present
        return self.word2idx[word]
def add_description(self):
counter = Counter()
for description in self.news_df.Description:
tokens = nltk.tokenize.word_tokenize(description.lower())
counter.update(tokens)
for word, cnt in counter.items():
if cnt >= self.vocab_threshold:
self.add_word(word)
print("description_vocab 길이 :", len(self.word2idx))
def __call__(self, word): # lookup word
return self.word2idx.get(word, self.unk_index)
def __len__(self):
return len(self.word2idx)
#data_dir = '/kaggle/input/ag-news-classification-dataset/'
data_dir = './dataset/'
train_news_csv= data_dir + "train.csv"
test_news_csv= data_dir + "test.csv"
train_news_df = pd.read_csv(train_news_csv)
test_news_df = pd.read_csv(test_news_csv)
len(train_news_df)
120000
from sklearn.model_selection import train_test_split
train_indices, valid_indices = train_test_split(range(len(train_news_df)),
stratify= train_news_df['Class Index'],
test_size=0.2)
len(train_indices), len(valid_indices)
(96000, 24000)
train_df = train_news_df.iloc[train_indices]
valid_df = train_news_df.iloc[valid_indices]
Class distribution
train_news_df['Class Index'].value_counts()/len(train_news_df)
3 0.25
4 0.25
2 0.25
1 0.25
Name: Class Index, dtype: float64
valid_df['Class Index'].value_counts()/len(valid_df)
3 0.25
2 0.25
4 0.25
1 0.25
Name: Class Index, dtype: float64
# Consists of class ids 1-4 where 1-World, 2-Sports, 3-Business, 4-Sci/Tech
category_map = {1:"World", 2:"Sports", 3:"Business", 4:"Sci/Tech"}
Dataset
vocab_threshold = 25
mask_word = "<mask>"
start_word = "<start>"
end_word = "<end>"
unk_word = "<unk>"
vocab_file = './vocab.pkl'
vocab_from_file = False
vocab = Vocabulary(vocab_threshold, vocab_file,
mask_word, start_word, end_word, unk_word,
train_news_df, vocab_from_file)
description_vocab size : 10320
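A few illustrative lookups against the vocabulary just built (the example words are arbitrary; any token seen fewer than vocab_threshold times maps to the <unk> index):
# Special tokens occupy the first four indices; rare or unseen words fall back to <unk>.
print(len(vocab))                                                              # 10320
print(vocab(mask_word), vocab(start_word), vocab(end_word), vocab(unk_word))   # 0 1 2 3
print(vocab("the"))          # a frequent word has its own index
print(vocab("qwertyuiop"))   # an out-of-vocabulary word -> unk index (3)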
def vectorize(text, vector_length = -1):
    # Input : 'Clijsters Unsure About Latest Injury, Says Hewitt'
    # Output: [1, 2, 4, 9, 10, 9, 2, 0, 0, 0, 0]
    # look up the id of each word in the text from the vocabulary
indices = [vocab.begin_seq_index]
word_list = nltk.tokenize.word_tokenize(text.lower())
for word in word_list:
indices.append(vocab(word))
indices.append(vocab.end_seq_index)
if vector_length < 0:
vector_length = len(indices)
out_vector = np.zeros(vector_length, dtype=np.int64)
out_vector[:len(indices)] = indices
out_vector[len(indices):] = vocab.mask_index
return out_vector
vectorize("I am a boy", -1)
array([ 1, 464, 7647, 21, 4123, 2])
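When a positive vector_length is given, the trailing positions are filled with the <mask> index (0), which is how fixed-length inputs would be padded, e.g.:
# Pad to a fixed length of 10; trailing positions become mask_index (0).
vectorize("I am a boy", 10)
# array([   1,  464, 7647,   21, 4123,    2,    0,    0,    0,    0])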
class NewsDataset(Dataset):
def __init__(self, mode, batch_size, vocab_threshold, vocab_file,
mask_word, start_word, end_word, unk_word,
news_df, vocab_from_file):
self.news_df = news_df
self.batch_size = batch_size
self.description_vocab = Vocabulary(vocab_threshold, vocab_file,
mask_word, start_word, end_word, unk_word,
train_news_df, vocab_from_file)
        # (1) vectorize each string to a fixed max_length
        # measure_len = lambda context : len(nltk.tokenize.word_tokenize(context.lower()))
        # self.max_seq_length = max(map(measure_len, train_news_df.Title)) + 2
        # (2) vectorize each string at its own (variable) length
self.description_lengths = [len(nltk.tokenize.word_tokenize(description.lower()))
for description in self.news_df.Description]
def __getitem__(self, index):
row = self.news_df.iloc[index]
description_vector = vectorize(row.Description, -1)
category_index = row['Class Index'] - 1
return {'x_data' : description_vector,
'y_target' : category_index }
def __len__(self):
return len(self.news_df)
def get_train_indices(self):
        # pick one description length at random from the whole dataset and
        # return the indices of descriptions that have exactly that length
        sel_length = np.random.choice(self.description_lengths)
        condition = [length == sel_length for length in self.description_lengths]
all_indices = np.where(condition)[0]
indices = list(np.random.choice(all_indices, size=self.batch_size))
return indices
train_batch_size = 32
valid_batch_size = 32
test_batch_size = 32
vocab_threshold = 25
mask_word = "<mask>"
start_word = "<start>"
end_word = "<end>"
unk_word = "<unk>"
vocab_file = './vocab.pkl'
vocab_from_file = False
trainset = NewsDataset("train", train_batch_size, vocab_threshold, vocab_file,
mask_word, start_word, end_word, unk_word,
train_df, vocab_from_file)
description_vocab size : 10320
vocab_from_file = True
validset = NewsDataset("valid", valid_batch_size, vocab_threshold, vocab_file,
mask_word, start_word, end_word, unk_word,
valid_df, vocab_from_file)
testset = NewsDataset("test", test_batch_size, vocab_threshold, vocab_file,
mask_word, start_word, end_word, unk_word,
test_news_df, vocab_from_file)
Vocabulary successfully loaded from vocab.pkl file!
Vocabulary successfully loaded from vocab.pkl file!
len(trainset), len(validset), len(testset)
(96000, 24000, 7600)
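get_train_indices performs simple length bucketing: it picks one description length at random and returns batch_size indices of descriptions with exactly that length, so every sequence in a batch can be stacked without padding. A quick check (illustrative):
# All sampled indices should point at descriptions with one common token length.
idxs = trainset.get_train_indices()
lengths = {trainset.description_lengths[i] for i in idxs}
print(len(idxs), lengths)  # 32 and a single length value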
3. Loading the Data: DataLoader
indices = trainset.get_train_indices()  # indices of descriptions that all share one length (batch_size of them)
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)  # shuffle them randomly
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=train_batch_size, drop_last=False)
trainloader = DataLoader(dataset=trainset, batch_sampler=batch_sampler, num_workers=2)
# !lscpu
indices = validset.get_train_indices()  # indices of descriptions that all share one length (batch_size of them)
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)  # shuffle them randomly
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=valid_batch_size, drop_last=False)
validloader = DataLoader(dataset=validset, batch_sampler=batch_sampler, num_workers=2)
indices = testset.get_train_indices()  # indices of descriptions that all share one length (batch_size of them)
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)  # shuffle them randomly
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=test_batch_size, drop_last=False)
testloader = DataLoader(dataset=testset, batch_sampler=batch_sampler, num_workers=2)
batch = next(iter(trainloader))
batch['x_data'].size(), batch['y_target'].size()
(torch.Size([32, 32]), torch.Size([32]))
batch = next(iter(validloader))
batch['x_data'].size(), batch['y_target'].size()
(torch.Size([32, 41]), torch.Size([32]))
batch = next(iter(testloader))
batch['x_data'].size(), batch['y_target'].size()
(torch.Size([32, 38]), torch.Size([32]))
len(trainloader), len(validloader), len(testloader)
(1, 1, 1)
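Each loader reports length 1 because its sampler holds exactly batch_size pre-selected indices, so it yields a single batch; this is why the training, validation, and evaluation loops below rebuild the sampler and loader at every step to draw a fresh same-length batch. A quick check (illustrative):
# The BatchSampler groups the 32 pre-selected indices into one batch of 32.
batches = list(trainloader.batch_sampler)
print(len(batches), len(batches[0]))  # 1 32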
5. Building the Model: NewsClassifier
class NewsClassifier(nn.Module):
def __init__(self, embedding_size, vocab_size,
rnn_hidden_dim, num_classes, dropout_p):
super().__init__()
self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
self.rnn = nn.RNN(input_size = embedding_size, hidden_size = rnn_hidden_dim,
batch_first = True)
self.classifier = nn.Linear(in_features=rnn_hidden_dim, out_features = num_classes)
def forward(self, inputs, apply_softmax=False): # input : (batch_size, description_length)
embeddings = self.emb(inputs) # embeddings : (batch_size, description_length, embedding_size)
_, hidden = self.rnn(embeddings) # outputs : (batch_size, description_length, rnn_hidden_dim)
# hidden : (num_layers , batch_size, rnn_hidden_dim)
hidden = hidden[0] # hidden : (batch_size, rnn_hidden_dim)
outputs = self.classifier(hidden) # outputs : (batch_size, num_classes)
if apply_softmax:
outputs = F.softmax(outputs, dim=1)
return outputs
Hyperparameter settings
embedding_size=100
rnn_hidden_dim = 100
learning_rate=0.001
num_epochs=12
classifier = NewsClassifier(embedding_size=embedding_size,
vocab_size=len(vocab),
rnn_hidden_dim = rnn_hidden_dim,
num_classes=4, dropout_p=None)
classifier = classifier.to(device)
classifier
NewsClassifier(
(emb): Embedding(10320, 100)
(rnn): RNN(100, 100, batch_first=True)
(classifier): Linear(in_features=100, out_features=4, bias=True)
)
out = classifier(batch['x_data'].to(device))
out.shape
torch.Size([32, 4])
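With apply_softmax=True the logits become class probabilities (each row sums to 1), which is what predict_category in section 9 relies on. A quick check (illustrative):
# Probabilities for the same batch; each row sums to (approximately) 1.
probs = classifier(batch['x_data'].to(device), apply_softmax=True)
print(probs.shape)           # torch.Size([32, 4])
print(probs.sum(dim=1)[:3])  # each entry close to 1.0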
6. Model Configuration (Loss Function and Optimizer)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
mode='min', factor=0.01,
patience=1, verbose=True)
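ReduceLROnPlateau watches the value passed to scheduler.step(); after patience epochs without improvement it multiplies the learning rate by factor (0.01 here, i.e. a 100-fold reduction). A standalone sketch of that behaviour with a throwaway optimizer (the values are hypothetical and unrelated to the training run):
# Feed a constant (non-improving) validation loss and watch the LR drop.
dummy_opt = optim.Adam([nn.Parameter(torch.zeros(1))], lr=0.001)
dummy_sched = optim.lr_scheduler.ReduceLROnPlateau(dummy_opt, mode='min',
                                                   factor=0.01, patience=1)
for epoch in range(4):
    dummy_sched.step(1.0)
    print(epoch, dummy_opt.param_groups[0]['lr'])
# lr stays at 0.001 for two epochs, then falls to 1e-05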
7. Training the Model
train_step = len(trainset) // train_batch_size
valid_step = len(validset) // valid_batch_size
test_step = len(testset) // test_batch_size
train_step, valid_step, test_step
(3000, 750, 237)
def validate(model, validloader, loss_fn):
model.eval()
total = 0
correct = 0
valid_loss = []
valid_epoch_loss=0
valid_accuracy = 0
with torch.no_grad():
for step in range(1, valid_step+1):
indices = validset.get_train_indices()
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
batch_size=valid_batch_size,
drop_last=False)
validloader= data.DataLoader(dataset=validset, num_workers=0,
batch_sampler=batch_sampler)
# Obtain the batch.
batch_dict = next(iter(validloader))
inputs = batch_dict['x_data'].to(device)
labels = batch_dict['y_target'].to(device)
            # forward pass and loss
logits = model(inputs)
loss = loss_fn(logits, labels)
valid_loss.append(loss.item())
            # accuracy
            _, preds = torch.max(logits, 1)  # final predictions for this batch
            # preds = logits.max(dim=1)[1]
            correct += int((preds == labels).sum())  # accumulate how many predictions were correct
            total += labels.shape[0]                 # accumulate the batch size
valid_epoch_loss = np.mean(valid_loss)
total_loss["val"].append(valid_epoch_loss)
valid_accuracy = correct / total
return valid_epoch_loss, valid_accuracy
def train_loop(model, trainloader, loss_fn, epochs, optimizer):
min_loss = 1000000
trigger = 0
patience = 3
for epoch in range(epochs):
model.train()
train_loss = []
for step in range(1, train_step+1):
indices = trainset.get_train_indices()
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
batch_size=train_batch_size,
drop_last=False)
trainloader= data.DataLoader(dataset=trainset, num_workers=2,
batch_sampler = batch_sampler)
# Obtain the batch.
batch_dict = next(iter(trainloader))
inputs = batch_dict['x_data'].to(device)
labels = batch_dict['y_target'].to(device)
optimizer.zero_grad()
logits = model(inputs)
loss = loss_fn(logits, labels)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
train_epoch_loss = np.mean(train_loss)
total_loss["train"].append(train_epoch_loss)
valid_epoch_loss, valid_accuracy = validate(model, validloader, loss_fn)
print("Epoch: {}/{}, Train Loss={:.4f}, Val Loss={:.4f}, Val Accyracy={:.4f}".format(
epoch + 1, epochs,
total_loss["train"][-1],
total_loss["val"][-1],
valid_accuracy))
        # Early stopping
        if valid_epoch_loss > min_loss:  # the validation loss failed to improve on min_loss
trigger += 1
print('trigger : ', trigger)
if trigger > patience:
print('Early Stopping !!!')
print('Training loop is finished !!')
return
else:
trigger = 0
min_loss = valid_epoch_loss
best_model_state = deepcopy(model.state_dict())
torch.save(best_model_state, 'best_checkpoint.pth')
# -------------------------------------------
# Learning Rate Scheduler
scheduler.step(valid_epoch_loss)
# -------------------------------------------
total_loss = {"train": [], "val": []}
%time train_loop(classifier, trainloader, loss_fn, num_epochs, optimizer)
import matplotlib.pyplot as plt
plt.plot(total_loss['train'], label="train_loss")
plt.plot(total_loss['val'], label="valid_loss")
plt.legend()
plt.show()
8. Evaluating the Model
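Early stopping saves the best weights to best_checkpoint.pth while classifier is left holding the last-epoch weights, so one option (a minimal sketch, assuming training above completed and wrote the file) is to restore the best checkpoint before evaluating:
# Restore the weights that achieved the lowest validation loss (optional).
classifier.load_state_dict(torch.load('best_checkpoint.pth', map_location=device))
classifier = classifier.to(device)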
def evaluate(model, testloader, loss_fn):
model.eval()
total = 0
correct = 0
test_loss = []
test_epoch_loss=0
test_accuracy = 0
with torch.no_grad():
for step in range(1, test_step+1):
indices = testset.get_train_indices()
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
batch_size=test_batch_size,
drop_last=False)
testloader= data.DataLoader(dataset=testset, num_workers=2,
batch_sampler=batch_sampler)
# Obtain the batch.
batch_dict = next(iter(testloader))
inputs = batch_dict['x_data'].to(device)
labels = batch_dict['y_target'].to(device)
            # forward pass and loss
logits = model(inputs)
loss = loss_fn(logits, labels)
test_loss.append(loss.item())
            # accuracy
            _, preds = torch.max(logits, 1)  # final predictions for this batch
            # preds = logits.max(dim=1)[1]
            correct += int((preds == labels).sum())  # accumulate how many predictions were correct
            total += labels.shape[0]                 # accumulate the batch size
test_epoch_loss = np.mean(test_loss)
# total_loss["val"].append(test_epoch_loss)
test_accuracy = correct / total
print('Test Loss : {:.5f}'.format(test_epoch_loss),
'Test Accuracy : {:.5f}'.format(test_accuracy))
evaluate(classifier, testloader, loss_fn)
Test Loss : 0.41626 Test Accuracy : 0.85539
9. Making Predictions
def predict_category(text, classifier, max_length):
    # predict the category of a news text
    # 1. vectorize
    vectorized_text = vectorize(text, vector_length=max_length)
    vectorized_text = torch.tensor(vectorized_text).unsqueeze(0)  # to tensor, add a batch dimension
    # 2. model prediction
    result = classifier(vectorized_text, apply_softmax=True)  # result : predicted probabilities
    probability, index = result.max(dim=1)
    predict = index.item() + 1  # class 0 in the model corresponds to class index 1 in the data
    probability = probability.item()
    predicted_category = category_map[predict]
    return {'category': predicted_category, 'probability': probability}
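A single-call usage example (the headline is made up for illustration; the returned category and probability depend on the trained weights):
# Classify one hypothetical headline with variable-length vectorization (-1).
sample_text = "Oil prices climb as OPEC signals production cuts"  # made-up example
predict_category(sample_text, classifier.to('cpu'), max_length=-1)
# returns a dict like {'category': ..., 'probability': ...}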
def get_samples():
    # collect samples grouped by their true category
    # keep 5 samples per class
    samples = {}
    for category in testset.news_df['Class Index'].unique():  # iterate over the class ids present in the test set
        samples[category] = testset.news_df.Description[testset.news_df['Class Index'] == category].tolist()[-5:]
return samples
test_samples = get_samples()
# Consists of class ids 1-4 where 1-World, 2-Sports, 3-Business, 4-Sci/Tech
category_map = {1:"World", 2:"Sports", 3:"Business", 4:"Sci/Tech"}
classifier = classifier.to('cpu')
for truth, sample_group in test_samples.items():
    print(f"True Category: {category_map[truth]}")
    print('='*50)
    for sample in sample_group:
        # NewsDataset does not define max_seq_length, so vectorize each sample at its own length (-1)
        prediction = predict_category(sample, classifier, -1)
        print("Prediction: {} (p={:0.2f})".format(prediction['category'], prediction['probability']))
        print("Sample: {}".format(sample))
        print('-'*30)
    print()
True Category: Business
==================================================
Prediction: Business (p=0.82)
Sample: Russia shrugs off US court freeze on oil giant Yukos auction
------------------------------
Prediction: Business (p=1.00)
Sample: Airbus chief wins fight to take controls at Eads
------------------------------
Prediction: Sci/Tech (p=0.55)
Sample: EBay #39;s Buy Of Rent.com May Lack Strategic Sense
------------------------------
Prediction: Business (p=1.00)
Sample: 5 of arthritis patients in Singapore take Bextra or Celebrex <b>...</b>
------------------------------
Prediction: Sci/Tech (p=0.58)
Sample: EBay gets into rentals
------------------------------
True Category: Sci/Tech
==================================================
Prediction: Sci/Tech (p=0.76)
Sample: Microsoft buy comes with strings attached
------------------------------
Prediction: Sci/Tech (p=0.94)
Sample: U.S. Army aims to halt paperwork with IBM system
------------------------------
Prediction: Sci/Tech (p=0.93)
Sample: Analysis: PeopleSoft users speak out about Oracle takeover (InfoWorld)
------------------------------
Prediction: Sci/Tech (p=0.96)
Sample: Hobbit-finding Boffins in science top 10
------------------------------
Prediction: Sci/Tech (p=0.86)
Sample: Search providers seek video, find challenges
------------------------------
True Category: Sports
==================================================
Prediction: World (p=0.71)
Sample: The Newest Hope ; Marriage of Necessity Just Might Work Out
------------------------------
Prediction: Business (p=0.79)
Sample: Saban hiring on hold
------------------------------
Prediction: Sports (p=0.72)
Sample: Mortaza strikes to lead superb Bangladesh rally
------------------------------
Prediction: Sports (p=0.53)
Sample: Void is filled with Clement
------------------------------
Prediction: Sports (p=0.99)
Sample: Martinez leaves bitter
------------------------------
True Category: World
==================================================
Prediction: Business (p=0.50)
Sample: Pricey Drug Trials Turn Up Few New Blockbusters
------------------------------
Prediction: World (p=1.00)
Sample: Bosnian-Serb prime minister resigns in protest against U.S. sanctions (Canadian Press)
------------------------------
Prediction: Business (p=0.68)
Sample: Historic Turkey-EU deal welcomed
------------------------------
Prediction: World (p=1.00)
Sample: Powell pushes diplomacy for N. Korea
------------------------------
Prediction: World (p=0.41)
Sample: Around the world
------------------------------
Syntax Reference
Counter
from collections import Counter
# Usage example (1)
s = 'life is short, so python is easy.'
counter = Counter(s)
counter
# Usage example (2)
s = 'life is short, so python is easy.'
counter = Counter()
tokens = s.split()
for token in tokens:
counter[token] += 1
counter
Counter({'life': 1, 'is': 2, 'short,': 1, 'so': 1, 'python': 1, 'easy.': 1})
# Usage example (3)
s = 'life is short, so python is easy.'
counter = Counter()
tokens = s.split()
counter.update(tokens)
counter
Counter({'life': 1, 'is': 2, 'short,': 1, 'so': 1, 'python': 1, 'easy.': 1})
# Usage example (4)
s = 'life is short, so python is easy.'
counter = Counter()
tokens = nltk.tokenize.word_tokenize(s)
counter.update(tokens)
counter
Counter({'life': 1,
'is': 2,
'short': 1,
',': 1,
'so': 1,
'python': 1,
'easy': 1,
'.': 1})
# For the string below, the tokenization result differs before and after lowercasing
s = "AP - Environmentalists asked the U.S. Fish and Wildlife Service on Wednesday to grant protected status to the California spotted owl, claiming the bird's old-growth forest habitat is threatened by logging."
counter = Counter()
tokens = nltk.tokenize.word_tokenize(s)
counter.update(tokens)
counter
s = "AP - Environmentalists asked the U.S. Fish and Wildlife Service on Wednesday to grant protected status to the California spotted owl, claiming the bird's old-growth forest habitat is threatened by logging."
counter = Counter()
tokens = nltk.tokenize.word_tokenize(s.lower())
counter.update(tokens)
counter
np.where
a = np.arange(10)
cond = a < 5
cond
array([ True, True, True, True, True, False, False, False, False,
False])
np.where(cond, a, a*10)
array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])
np.where(cond)  # with only a condition, this gives the same result as np.asarray(cond).nonzero() below
(array([0, 1, 2, 3, 4]),)
np.asarray(cond).nonzero()
(array([0, 1, 2, 3, 4]),)
description_lengths = [37, 38, 45, 2, 3, 37, 37, 45, 45, 50]
sel_length = 37
cond = [description_lengths[i] == sel_length for i in np.arange(len(description_lengths))]
indices = np.where(cond)
indices
(array([0, 5, 6]),)
BatchSampler
indices = range(10)
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=3, drop_last=False)
list(batch_sampler)
[[9, 2, 0], [6, 7, 5], [3, 8, 4], [1]]
indices = range(32)  # indices of descriptions that share the same length
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)  # shuffle them randomly
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=32, drop_last=True)  # group the sampled indices into batches
list(batch_sampler)