
Predicting AG News Categories with an RNN Model

import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import pickle
import nltk
nltk.download('punkt')

import string
from collections import Counter
from copy import deepcopy
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
# Fix random seeds for reproducibility
seed = 50
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)                # fix the Python random seed
np.random.seed(seed)             # fix the NumPy random seed
torch.manual_seed(seed)          # fix the PyTorch random seed (CPU)
torch.cuda.manual_seed(seed)     # fix the PyTorch random seed (single GPU)
torch.cuda.manual_seed_all(seed) # fix the PyTorch random seed (multi-GPU)
torch.backends.cudnn.deterministic = True # use deterministic operations
torch.backends.cudnn.benchmark = False    # disable benchmark mode
torch.backends.cudnn.enabled = False      # disable cuDNN
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cpu')

1. Downloading the Data

# Install the package for the Kaggle API
!pip install kaggle

# kaggle.json upload
from google.colab import files
files.upload()

# Prevent permission warnings
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# download
!kaggle datasets download -d amananandrai/ag-news-classification-dataset

!mkdir dataset
# unzip the archive
!unzip -q ag-news-classification-dataset.zip -d dataset/
Saving kaggle.json to kaggle.json
Downloading ag-news-classification-dataset.zip to /content
 79% 9.00M/11.4M [00:00<00:00, 83.3MB/s]
100% 11.4M/11.4M [00:00<00:00, 97.5MB/s]

2. Reading the Data

Vocabulary

class Vocabulary():
    def __init__(self, vocab_threshold, vocab_file,
                 mask_word="<mask>",
                 start_word="<start>",
                 end_word="<end>",
                 unk_word="<unk>",
                 news_df=None, vocab_from_file=False):
        
        self.vocab_threshold = vocab_threshold
        # the full dataset (train_news_df) before the train/valid split
        self.news_df = news_df

        # special tokens
        self.mask_word = mask_word
        self.start_word = start_word
        self.end_word = end_word
        self.unk_word = unk_word

        # initialize the dictionaries
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
                
        if vocab_from_file:
            # load a previously built vocabulary from file
            with open(vocab_file, 'rb') as f:
                vocab = pickle.load(f)
                self.word2idx = vocab.word2idx
                self.idx2word = vocab.idx2word
                self.mask_index = vocab.mask_index
                self.begin_seq_index = vocab.begin_seq_index
                self.end_seq_index = vocab.end_seq_index
                self.unk_index = vocab.unk_index
            print('Vocabulary successfully loaded from vocab.pkl file!')
        else:
            self.build_vocab()
            with open(vocab_file, 'wb') as f:
                pickle.dump(self, f)

    def build_vocab(self):
        # mask_word (0), start_word (1), end_word (2), unk_word (3)
        self.mask_index = self.add_word(self.mask_word)       # 0
        self.begin_seq_index = self.add_word(self.start_word) # 1
        self.end_seq_index = self.add_word(self.end_word)     # 2
        self.unk_index = self.add_word(self.unk_word)         # 3
        self.add_description()

    def add_word(self, word):
        # add the word if it is new, then return its index
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1
        return self.word2idx[word]

    def add_description(self):
        counter = Counter()
        for description in self.news_df.Description:
            tokens = nltk.tokenize.word_tokenize(description.lower())
            counter.update(tokens)

        for word, cnt in counter.items():
            if cnt >= self.vocab_threshold:
                self.add_word(word)

        print("description_vocab 길이 :", len(self.word2idx))      

    def __call__(self, word): # lookup word
        return self.word2idx.get(word, self.unk_index)

    def __len__(self):
        return len(self.word2idx)
#data_dir = '/kaggle/input/ag-news-classification-dataset/'
data_dir = './dataset/'
train_news_csv= data_dir + "train.csv"
test_news_csv= data_dir + "test.csv"
train_news_df = pd.read_csv(train_news_csv)
test_news_df = pd.read_csv(test_news_csv)
len(train_news_df)
120000
from sklearn.model_selection import train_test_split

train_indices, valid_indices = train_test_split(range(len(train_news_df)), 
                                                stratify= train_news_df['Class Index'], 
                                                test_size=0.2)
len(train_indices), len(valid_indices)
(96000, 24000)
train_df = train_news_df.iloc[train_indices]
valid_df = train_news_df.iloc[valid_indices]

Class Distribution

train_news_df['Class Index'].value_counts()/len(train_news_df)
3    0.25
4    0.25
2    0.25
1    0.25
Name: Class Index, dtype: float64
valid_df['Class Index'].value_counts()/len(valid_df)
3    0.25
2    0.25
4    0.25
1    0.25
Name: Class Index, dtype: float64
# Consists of class ids 1-4 where 1-World, 2-Sports, 3-Business, 4-Sci/Tech
category_map = {1:"World", 2:"Sports", 3:"Business", 4:"Sci/Tech"}

Dataset

vocab_threshold = 25
mask_word = "<mask>"
start_word = "<start>"
end_word = "<end>"
unk_word = "<unk>"
vocab_file = './vocab.pkl'
vocab_from_file = False

vocab = Vocabulary(vocab_threshold, vocab_file, 
                    mask_word, start_word, end_word, unk_word, 
                    train_news_df, vocab_from_file)
description_vocab length : 10320
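A quick lookup check (a minimal sketch, not part of the original notebook; the probe words are arbitrary): tokens in the vocabulary return their own ids, while unseen or below-threshold tokens fall back to the <unk> index.

# known tokens return their ids; unseen tokens return the <unk> index (3)
print(vocab('the'), vocab('qwertyuiopasdf'), len(vocab))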
def vectorize(text, vector_length = -1):
    # Input  : 'Clijsters Unsure About Latest Injury, Says Hewitt'
    # Output : [1, 2, 4, 9, 10, 9, 2, 0, 0, 0, 0]
    # Look up the id of each word of the text in the vocabulary

    indices = [vocab.begin_seq_index]
    word_list = nltk.tokenize.word_tokenize(text.lower())
    for word in word_list:
        indices.append(vocab(word))
    indices.append(vocab.end_seq_index)

    if vector_length < 0:
        vector_length = len(indices)

    out_vector = np.zeros(vector_length, dtype=np.int64)
    out_vector[:len(indices)] = indices
    out_vector[len(indices):] = vocab.mask_index

    return out_vector    
vectorize("I am a boy", -1)
array([   1,  464, 7647,   21, 4123,    2])
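Decoding the vector back through idx2word makes the special tokens visible (a small sketch using the vocab built above; words below the frequency threshold would show up as <unk>):

# map the ids back to tokens: <start> ... <end>
print([vocab.idx2word[i] for i in vectorize("I am a boy", -1)])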

class NewsDataset(Dataset):
  def __init__(self, mode, batch_size, vocab_threshold, vocab_file,
               mask_word, start_word, end_word, unk_word,
               news_df, vocab_from_file):
      self.news_df = news_df
      self.batch_size = batch_size
      self.description_vocab = Vocabulary(vocab_threshold, vocab_file,
                                    mask_word, start_word, end_word, unk_word,
                                    train_news_df, vocab_from_file)
      # (1) When vectorizing each string to a fixed max_length
      # measure_len = lambda context : len(nltk.tokenize.word_tokenize(context.lower()))
      # self.max_seq_length = max(map(measure_len, train_news_df.Title)) + 2
      
      # (2) When vectorizing each string with a variable length
      self.description_lengths = [len(nltk.tokenize.word_tokenize(description.lower()))  
                                         for description in self.news_df.Description]

  def __getitem__(self, index):
      row = self.news_df.iloc[index]
      description_vector = vectorize(row.Description, -1)
      category_index = row['Class Index'] - 1
      return {'x_data' : description_vector, 
              'y_target' : category_index }

  def __len__(self):
      return len(self.news_df)

  def get_train_indices(self):
      # Pick one description length at random from the whole dataset
      # and return indices of descriptions that have the same length
      sel_length = np.random.choice(self.description_lengths)
      condition = [length == sel_length for length in self.description_lengths]
      all_indices = np.where(condition)[0]
      indices = list(np.random.choice(all_indices, size=self.batch_size))
      return indices

train_batch_size = 32
valid_batch_size = 32
test_batch_size = 32
vocab_threshold = 25
mask_word = "<mask>"
start_word = "<start>"
end_word = "<end>"
unk_word = "<unk>"
vocab_file = './vocab.pkl'
vocab_from_file = False

trainset = NewsDataset("train", train_batch_size, vocab_threshold, vocab_file,
                       mask_word, start_word, end_word, unk_word,
                       train_df, vocab_from_file)
description_vocab length : 10320
vocab_from_file = True
validset = NewsDataset("valid", valid_batch_size, vocab_threshold, vocab_file,
                       mask_word, start_word, end_word, unk_word,
                       valid_df, vocab_from_file)
testset = NewsDataset("test", test_batch_size, vocab_threshold, vocab_file,
                       mask_word, start_word, end_word, unk_word,
                       test_news_df, vocab_from_file)
Vocabulary successfully loaded from vocab.pkl file!
Vocabulary successfully loaded from vocab.pkl file!
len(trainset), len(validset), len(testset)
(96000, 24000, 7600)
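A sanity check on the length-bucketing idea (a minimal sketch, not in the original notebook): every index returned by get_train_indices should point to a description with the same tokenized length, so a batch can be stacked without padding.

# all selected descriptions share one tokenized length
idxs = trainset.get_train_indices()
print({trainset.description_lengths[i] for i in idxs})  # expected: a single value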

3. Loading the Data: DataLoader

indices = trainset.get_train_indices() # batch_size indices of descriptions with the same length
initial_sampler = data.sampler.SubsetRandomSampler(indices = indices) # shuffles them randomly
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=train_batch_size, drop_last=False)
trainloader = DataLoader(dataset=trainset, batch_sampler=batch_sampler, num_workers=2)
# !lscpu
indices = validset.get_train_indices() # batch_size indices of descriptions with the same length
initial_sampler = data.sampler.SubsetRandomSampler(indices = indices) # shuffles them randomly
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=valid_batch_size, drop_last=False)
validloader = DataLoader(dataset=validset, batch_sampler=batch_sampler, num_workers=2)
indices = testset.get_train_indices() # batch_size indices of descriptions with the same length
initial_sampler = data.sampler.SubsetRandomSampler(indices = indices) # shuffles them randomly
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=test_batch_size, drop_last=False)
testloader = DataLoader(dataset=testset, batch_sampler=batch_sampler, num_workers=2)
batch = next(iter(trainloader))
batch['x_data'].size(), batch['y_target'].size()
(torch.Size([32, 32]), torch.Size([32]))
batch = next(iter(validloader))
batch['x_data'].size(), batch['y_target'].size()
(torch.Size([32, 41]), torch.Size([32]))
batch = next(iter(testloader))
batch['x_data'].size(), batch['y_target'].size()
(torch.Size([32, 38]), torch.Size([32]))
len(trainloader), len(validloader), len(testloader)
(1, 1, 1)
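Because each sampler holds exactly batch_size indices, every loader above yields a single batch, which is why all three lengths are 1; the training loop below therefore rebuilds the sampler and loader at every step. A small helper that wraps this pattern (a sketch only; the helper name is mine and is not used elsewhere in this notebook):

def make_equal_length_batch(dataset, batch_size, num_workers=0):
    # draw batch_size indices of equal-length descriptions and return one batch
    idxs = dataset.get_train_indices()
    sampler = data.sampler.SubsetRandomSampler(indices=idxs)
    bs = data.sampler.BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=False)
    loader = DataLoader(dataset=dataset, batch_sampler=bs, num_workers=num_workers)
    return next(iter(loader))

make_equal_length_batch(trainset, train_batch_size)['x_data'].shape  # e.g. torch.Size([32, L])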

5. Building the Model: NewsClassifier


class NewsClassifier(nn.Module):
    def __init__(self, embedding_size, vocab_size,  
                 rnn_hidden_dim, num_classes, dropout_p):
        
        super().__init__()

        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        self.rnn = nn.RNN(input_size = embedding_size, hidden_size = rnn_hidden_dim, 
                          batch_first = True)
        self.classifier = nn.Linear(in_features=rnn_hidden_dim, out_features = num_classes)
 
            
    def forward(self, inputs, apply_softmax=False): # input : (batch_size, description_length)
        embeddings = self.emb(inputs) # embeddings : (batch_size, description_length, embedding_size)
        _, hidden = self.rnn(embeddings) # outputs : (batch_size, description_length, rnn_hidden_dim)
                                               # hidden : (num_layers , batch_size, rnn_hidden_dim)
        hidden = hidden[0] # hidden : (batch_size, rnn_hidden_dim)
        outputs = self.classifier(hidden) # outputs : (batch_size, num_classes)

        if apply_softmax:
            outputs = F.softmax(outputs, dim=1)
        return outputs   

Hyperparameter Settings

embedding_size=100
rnn_hidden_dim = 100
learning_rate=0.001
num_epochs=12
classifier = NewsClassifier(embedding_size=embedding_size, 
                            vocab_size=len(vocab),    
                            rnn_hidden_dim = rnn_hidden_dim,                      
                            num_classes=4, dropout_p=None)
classifier = classifier.to(device)
classifier
NewsClassifier(
  (emb): Embedding(10320, 100)
  (rnn): RNN(100, 100, batch_first=True)
  (classifier): Linear(in_features=100, out_features=4, bias=True)
)
out = classifier(batch['x_data'].to(device))
out.shape
torch.Size([32, 4])
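Tracing the intermediate shapes through the layers on a dummy batch can make the forward pass easier to follow (a sketch; the dummy batch size and sequence length are arbitrary):

x = torch.randint(0, len(vocab), (4, 20)).to(device)   # (batch=4, seq_len=20) token ids
e = classifier.emb(x)                                   # (4, 20, 100)
o, h = classifier.rnn(e)                                # o: (4, 20, 100), h: (1, 4, 100)
print(classifier.classifier(h[0]).shape)                # torch.Size([4, 4])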

6. Model Configuration (Loss Function and Optimizer)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                           mode='min', factor=0.01,
                                           patience=1, verbose=True)
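ReduceLROnPlateau shrinks the learning rate (new_lr = lr * factor) once the monitored value, here the validation loss passed to scheduler.step(), fails to improve for `patience` consecutive calls. The current learning rate can be read from the optimizer (a minimal sketch):

# current learning rate before any reduction
print(optimizer.param_groups[0]['lr'])   # 0.001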

7. Training the Model

train_step = len(trainset) // train_batch_size
valid_step = len(validset) // valid_batch_size
test_step = len(testset) // test_batch_size
train_step, valid_step, test_step
(3000, 750, 237)
def validate(model, validloader, loss_fn):
    model.eval()
    total = 0   
    correct = 0
    valid_loss = []
    valid_epoch_loss=0
    valid_accuracy = 0

    with torch.no_grad():
        for step in range(1, valid_step+1):        
            indices = validset.get_train_indices()
            initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
                                                    batch_size=valid_batch_size,
                                                    drop_last=False)
            validloader= data.DataLoader(dataset=validset, num_workers=0,
                                         batch_sampler=batch_sampler)
            # Obtain the batch.
            batch_dict = next(iter(validloader))
            inputs = batch_dict['x_data'].to(device)
            labels = batch_dict['y_target'].to(device)

            # Forward pass and loss
            logits = model(inputs)
            loss = loss_fn(logits, labels)
            valid_loss.append(loss.item())
            
            # Accuracy
            _, preds = torch.max(logits, 1) # final predictions for this batch
            # preds = logits.max(dim=1)[1] 
            correct += int((preds == labels).sum()) # accumulate the number of correct predictions
            total += labels.shape[0] # accumulate the batch size into total

    valid_epoch_loss = np.mean(valid_loss)
    total_loss["val"].append(valid_epoch_loss)
    valid_accuracy = correct / total

    return valid_epoch_loss, valid_accuracy
def train_loop(model, trainloader, loss_fn, epochs, optimizer):  
    min_loss = 1000000  
    trigger = 0
    patience = 3     

    for epoch in range(epochs):
        model.train()
        train_loss = []

        for step in range(1, train_step+1): 
            indices = trainset.get_train_indices()
            initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
                                                    batch_size=train_batch_size,
                                                    drop_last=False)
            trainloader= data.DataLoader(dataset=trainset, num_workers=2,
                                         batch_sampler = batch_sampler)
            
            # Obtain the batch.
            batch_dict = next(iter(trainloader))
            inputs = batch_dict['x_data'].to(device)
            labels = batch_dict['y_target'].to(device)


            optimizer.zero_grad()
            logits = model(inputs)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        train_epoch_loss = np.mean(train_loss)
        total_loss["train"].append(train_epoch_loss)
 
        valid_epoch_loss, valid_accuracy = validate(model, validloader, loss_fn)

        print("Epoch: {}/{}, Train Loss={:.4f}, Val Loss={:.4f}, Val Accyracy={:.4f}".format(
                epoch + 1, epochs,
                total_loss["train"][-1],
                total_loss["val"][-1],
                valid_accuracy))  
               
        # Early Stopping
        if valid_epoch_loss > min_loss: # if valid_epoch_loss fails to improve on min_loss
          trigger += 1
          print('trigger : ', trigger)
          if trigger > patience:
            print('Early Stopping !!!')
            print('Training loop is finished !!')
            return
        else:
          trigger = 0
          min_loss = valid_epoch_loss
          best_model_state = deepcopy(model.state_dict())          
          torch.save(best_model_state, 'best_checkpoint.pth') 
        # -------------------------------------------

        # Learning Rate Scheduler
        scheduler.step(valid_epoch_loss)
        # -------------------------------------------                      
total_loss = {"train": [], "val": []}
%time train_loop(classifier, trainloader, loss_fn, num_epochs, optimizer)
import matplotlib.pyplot as plt

plt.plot(total_loss['train'], label="train_loss")
plt.plot(total_loss['val'], label="valid_loss")
plt.legend()
plt.show()


8. Evaluating the Model

def evaluate(model, testloader, loss_fn):
    model.eval()
    total = 0   
    correct = 0
    test_loss = []
    test_epoch_loss=0
    test_accuracy = 0

    with torch.no_grad():
        for step in range(1, test_step+1):        
            indices = testset.get_train_indices()
            initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
                                                    batch_size=test_batch_size,
                                                    drop_last=False)
            testloader= data.DataLoader(dataset=testset, num_workers=2,
                                        batch_sampler=batch_sampler)

            # Obtain the batch.
            batch_dict = next(iter(testloader))
            inputs = batch_dict['x_data'].to(device)
            labels = batch_dict['y_target'].to(device)

            # Forward pass and loss
            logits = model(inputs)
            loss = loss_fn(logits, labels)
            test_loss.append(loss.item())
            
            # Accuracy
            _, preds = torch.max(logits, 1) # final predictions for this batch
            # preds = logits.max(dim=1)[1] 
            correct += int((preds == labels).sum()) # accumulate the number of correct predictions
            total += labels.shape[0] # accumulate the batch size into total

    test_epoch_loss = np.mean(test_loss)
    # total_loss["val"].append(test_epoch_loss)
    test_accuracy = correct / total

    print('Test Loss : {:.5f}'.format(test_epoch_loss), 
        'Test Accuracy : {:.5f}'.format(test_accuracy))
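Since the early-stopping branch above saved the best weights to best_checkpoint.pth, they can optionally be restored before the final evaluation (a sketch; as written, the evaluation below runs on whatever state classifier currently holds):

# optionally restore the best checkpoint saved during training
classifier.load_state_dict(torch.load('best_checkpoint.pth', map_location=device))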


evaluate(classifier, testloader, loss_fn)  
Test Loss : 0.41626 Test Accuracy : 0.85539

9. Making Predictions

def predict_category(text, classifier, max_length):
    # Predict the category of a news text
    
    # 1. vectorize
    vectorized_text = vectorize(text, vector_length=max_length)
    vectorized_text = torch.tensor(vectorized_text).unsqueeze(0) # convert to a tensor and add a batch dimension

    # 2. model prediction
    result = classifier(vectorized_text, apply_softmax=True) # result: predicted probabilities
    probability, index = result.max(dim=1)
    predict = index.item() + 1 # class 0 in the model corresponds to class 1 in the original data
    probability = probability.item()
    predicted_category = category_map[predict]
    
    return {'category':predicted_category, 'probability':probability}    
    
def get_samples():
    # Collect samples grouped by their true category
    # Prepare 5 samples per class
    samples = {}

    for category in testset.news_df['Class Index'].unique(): # 1=>2=>3=>4
        samples[category]= testset.news_df.Description[testset.news_df['Class Index'] == category].tolist()[-5:]
    
    return samples

test_samples = get_samples() 
# Consists of class ids 1-4 where 1-World, 2-Sports, 3-Business, 4-Sci/Tech
category_map = {1:"World", 2:"Sports", 3:"Business", 4:"Sci/Tech"}
classifier = classifier.to('cpu')
for truth, sample_group in test_samples.items():
    print(f"True Category: {category_map[truth]}")
    print('='*50)
    for sample in sample_group:
        prediction = predict_category(sample, classifier, -1) # -1: variable-length vectorization
        print("Prediction: {} (p={:0.2f})".format(prediction['category'], prediction['probability']))
        print("Sample: {}".format(sample))
        print('-'*30)
    print()
True Category: Business
==================================================
Prediction: Business (p=0.82)
Sample: Russia shrugs off US court freeze on oil giant Yukos auction
------------------------------
Prediction: Business (p=1.00)
Sample: Airbus chief wins fight to take controls at Eads
------------------------------
Prediction: Sci/Tech (p=0.55)
Sample: EBay #39;s Buy Of Rent.com May Lack Strategic Sense
------------------------------
Prediction: Business (p=1.00)
Sample: 5 of arthritis patients in Singapore take Bextra or Celebrex &lt;b&gt;...&lt;/b&gt;
------------------------------
Prediction: Sci/Tech (p=0.58)
Sample: EBay gets into rentals
------------------------------

True Category: Sci/Tech
==================================================
Prediction: Sci/Tech (p=0.76)
Sample: Microsoft buy comes with strings attached
------------------------------
Prediction: Sci/Tech (p=0.94)
Sample: U.S. Army aims to halt paperwork with IBM system
------------------------------
Prediction: Sci/Tech (p=0.93)
Sample: Analysis: PeopleSoft users speak out about Oracle takeover (InfoWorld)
------------------------------
Prediction: Sci/Tech (p=0.96)
Sample: Hobbit-finding Boffins in science top 10
------------------------------
Prediction: Sci/Tech (p=0.86)
Sample: Search providers seek video, find challenges
------------------------------

True Category: Sports
==================================================
Prediction: World (p=0.71)
Sample: The Newest Hope ; Marriage of Necessity Just Might Work Out
------------------------------
Prediction: Business (p=0.79)
Sample: Saban hiring on hold
------------------------------
Prediction: Sports (p=0.72)
Sample: Mortaza strikes to lead superb Bangladesh rally
------------------------------
Prediction: Sports (p=0.53)
Sample: Void is filled with Clement
------------------------------
Prediction: Sports (p=0.99)
Sample: Martinez leaves bitter
------------------------------

True Category: World
==================================================
Prediction: Business (p=0.50)
Sample: Pricey Drug Trials Turn Up Few New Blockbusters
------------------------------
Prediction: World (p=1.00)
Sample: Bosnian-Serb prime minister resigns in protest against U.S. sanctions (Canadian Press)
------------------------------
Prediction: Business (p=0.68)
Sample: Historic Turkey-EU deal welcomed
------------------------------
Prediction: World (p=1.00)
Sample: Powell pushes diplomacy for N. Korea
------------------------------
Prediction: World (p=0.41)
Sample: Around the world
------------------------------

Syntax Reference

Counter

from collections import Counter
# Usage example (1)
s = 'life is short, so python is easy.'

counter = Counter(s)
counter
# Usage example (2)
s = 'life is short, so python is easy.'

counter = Counter()
tokens = s.split()
for token in tokens:
    counter[token] += 1
counter    
Counter({'life': 1, 'is': 2, 'short,': 1, 'so': 1, 'python': 1, 'easy.': 1})
# Usage example (3)
s = 'life is short, so python is easy.'

counter = Counter()
tokens = s.split()
counter.update(tokens)
counter 
Counter({'life': 1, 'is': 2, 'short,': 1, 'so': 1, 'python': 1, 'easy.': 1})
# Usage example (4)
s = 'life is short, so python is easy.'

counter = Counter()
tokens = nltk.tokenize.word_tokenize(s)
counter.update(tokens)
counter 
Counter({'life': 1,
         'is': 2,
         'short': 1,
         ',': 1,
         'so': 1,
         'python': 1,
         'easy': 1,
         '.': 1})
# For the string below, tokenizing before and after lowercasing gives different results
s = "AP - Environmentalists asked the U.S. Fish and Wildlife Service on Wednesday to grant protected status to the California spotted owl, claiming the bird's old-growth forest habitat is threatened by logging."
counter = Counter()
tokens = nltk.tokenize.word_tokenize(s)
counter.update(tokens)
counter 
s = "AP - Environmentalists asked the U.S. Fish and Wildlife Service on Wednesday to grant protected status to the California spotted owl, claiming the bird's old-growth forest habitat is threatened by logging."
counter = Counter()
tokens = nltk.tokenize.word_tokenize(s.lower())
counter.update(tokens)
counter 
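most_common() is also handy when deciding on a vocab_threshold, since it lists (token, count) pairs from most to least frequent (a small sketch using the counter above):

counter.most_common(5)  # the five most frequent tokens with their counts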

np.where

a = np.arange(10)
cond = a < 5
cond
array([ True,  True,  True,  True,  True, False, False, False, False,
       False])
np.where(cond, a, a*10)
array([ 0,  1,  2,  3,  4, 50, 60, 70, 80, 90])
np.where(cond) # with only a condition, this gives the same result as np.asarray(cond).nonzero() below
(array([0, 1, 2, 3, 4]),)
np.asarray(cond).nonzero()
(array([0, 1, 2, 3, 4]),)
description_lengths = [37, 38, 45, 2, 3, 37, 37, 45, 45, 50]
sel_length = 37
cond = [description_lengths[i] == sel_length for i in np.arange(len(description_lengths))]
indices = np.where(cond)
indices 
(array([0, 5, 6]),)
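The same selection can be written without the Python-level list comprehension by comparing a NumPy array directly (a sketch):

np.where(np.array(description_lengths) == sel_length)[0]  # array([0, 5, 6])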

BatchSampler

indices = range(10)
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=3, drop_last=False)
list(batch_sampler)
[[9, 2, 0], [6, 7, 5], [3, 8, 4], [1]]
indices = range(32) # indices of descriptions with the same length
initial_sampler = data.sampler.SubsetRandomSampler(indices=indices) # shuffles them randomly
batch_sampler = data.sampler.BatchSampler(sampler=initial_sampler, batch_size=32, drop_last=True) # groups the sampled data into batches
list(batch_sampler)
