
CBOW Negative Sampling with PTB Dataset

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')

1. Download the Data

!mkdir ptb_dataset
!wget https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt -P ./ptb_dataset
--2023-04-25 01:52:20--  https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5101618 (4.9M) [text/plain]
Saving to: ‘./ptb_dataset/ptb.train.txt’

ptb.train.txt       100%[===================>]   4.87M  --.-KB/s    in 0.03s   

2023-04-25 01:52:21 (145 MB/s) - ‘./ptb_dataset/ptb.train.txt’ saved [5101618/5101618]

2. Read the Data

dataset_dir = './ptb_dataset/'
train_file_name = 'ptb.train.txt'
def create_contexts_target(corpus, window_size=1):
    target = corpus[window_size:-window_size]
    contexts = []

    for idx in range(window_size, len(corpus)-window_size):
        cs = []
        for t in range(-window_size, window_size + 1):
            if t == 0:
                continue
            cs.append(corpus[idx + t])
        contexts.append(cs)

    return np.array(contexts), np.array(target)
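
For intuition, here is a minimal check of the windowing on a toy corpus (a hypothetical example, separate from the PTB pipeline below):

toy_contexts, toy_target = create_contexts_target(np.array([0, 1, 2, 3, 4]), window_size=1)
toy_contexts  # array([[0, 2], [1, 3], [2, 4]]) -- the surrounding words for each target
toy_target    # array([1, 2, 3]) -- every word that has a full window on both sides
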
import collections
GPU = True
class UnigramSampler:
    def __init__(self, corpus, power, sample_size):
        self.sample_size = sample_size
        self.vocab_size = None
        self.word_p = None

        counts = collections.Counter()
        for word_id in corpus:
            counts[word_id] += 1

        vocab_size = len(counts)
        self.vocab_size = vocab_size

        self.word_p = np.zeros(vocab_size)
        for i in range(vocab_size):
            self.word_p[i] = counts[i]

        self.word_p = np.power(self.word_p, power)
        self.word_p /= np.sum(self.word_p)

    def get_negative_sample(self, target):
        batch_size = target.shape[0]

        if not GPU:
            negative_sample = np.zeros((batch_size, self.sample_size), dtype=np.int32)

            for i in range(batch_size):
                p = self.word_p.copy()
                target_idx = target[i]
                p[target_idx] = 0
                p /= p.sum()
                negative_sample[i, :] = np.random.choice(self.vocab_size, size=self.sample_size, replace=False, p=p)
        else:
            # Fast path (originally intended for GPU/cupy computation): prioritize speed.
            # The target word itself may be included among the negative samples.
            negative_sample = np.random.choice(self.vocab_size, size=(batch_size, self.sample_size),
                                               replace=True, p=self.word_p)

        return negative_sample
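
As a quick sanity check, the sampler can be exercised on a toy corpus (hypothetical word ids). With GPU = True it samples with replacement from the unigram distribution raised to the 0.75 power, so a target may occasionally reappear among its own negatives.

toy_sampler = UnigramSampler(np.array([0, 1, 1, 2, 2, 2]), power=0.75, sample_size=2)
toy_sampler.word_p                                        # counts [1, 2, 3] raised to 0.75, then normalized
toy_sampler.get_negative_sample(np.array([2, 0])).shape   # (2, 2): sample_size negatives per target
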
class PTBDataset(Dataset):
    def __init__(self, file_path, window_size, sample_size):
        self.file_path = file_path
        self.word_to_id, self.id_to_word, self.words = self.load_vocab()

        corpus = np.array([self.word_to_id[w] for w in self.words])
        print('corpus size :', len(corpus))

        self.contexts, self.target = create_contexts_target(corpus, window_size)
        print('contexts.shape:', self.contexts.shape, 'target.shape:', self.target.shape)
        
        self.sampler = UnigramSampler(corpus, 0.75, sample_size)

    def load_vocab(self):
        words = open(self.file_path).read().replace('\n', '<eos>').strip().split()
        word_to_id = {}
        id_to_word = {}

        for i, word in enumerate(words):
            if word not in word_to_id:
                new_id = len(word_to_id)
                word_to_id[word] = new_id
                id_to_word[new_id] = word
        print('vocab size:', len(id_to_word))        
        return word_to_id, id_to_word, words

    def __len__(self):
        return len(self.target)

    def __getitem__(self, index):
        return self.contexts[index], self.target[index]

Preparation for measuring word similarity after training

dataset_dir = './ptb_dataset/'
train_file_name = 'ptb.train.txt'
file_path = dataset_dir + train_file_name
words = open(file_path).read().replace('\n', '<eos>').strip().split()
words # split the corpus text into word tokens and keep them in a list
['aer',
 'banknote',
 'berlitz',
 'calloway',
 'centrust',
 'cluett',
 'fromstein',
 'gitano',
 'guterman',
 'hydro-quebec',
 'ipo',
 'kia',
 'memotec',
 'mlx',
 'nahb',
 'punts',
 'rake',
 'regatta',
 'rubens',
 'sim',
 'snack-food',
 'ssangyong',
 'swapo',
 'wachter',
 '<eos>',
 'pierre',
 '<unk>',
 'N',
 'years',
 'old',
 'will',
 'join',
 'the',
 'board',
 'as',
 'a',
 'nonexecutive',
 'director',
 'nov.',
 'N',
 '<eos>',
 'mr.',
 '<unk>',
 'is',
 'chairman',
 'of',
 '<unk>',
 'n.v.',
 'the',
 'dutch',
 'publishing',
 'group',
 '<eos>',
 'rudolph',
 '<unk>',
 'N',
 'years',
 'old',
 'and',
 'former',
 'chairman',
 'of',
 'consolidated',
 'gold',
 'fields',
 'plc',
 'was',
 'named',
 'a',
 'nonexecutive',
 'director',
 'of',
 'this',
 'british',
 'industrial',
 'conglomerate',
 '<eos>',
 'a',
 'form',
 'of',
 'asbestos',
 'once',
 'used',
 'to',
 'make',
 'kent',
 'cigarette',
 'filters',
 'has',
 'caused',
 'a',
 'high',
 'percentage',
 'of',
 'cancer',
 'deaths',
 'among',
 'a',
 'group',
 'of',
 'workers',
 'exposed',
 'to',
 'it',
 'more',
 'than',
 'N',
 'years',
 'ago',
 'researchers',
 'reported',
 '<eos>',
 'the',
 'asbestos',
 'fiber',
 '<unk>',
 'is',
 'unusually',
 '<unk>',
 'once',
 'it',
 'enters',
 'the',
 '<unk>',
 'with',
 'even',
 'brief',
 'exposures',
 'to',
 'it',
 'causing',
 'symptoms',
 'that',
 'show',
 'up',
 'decades',
 'later',
 'researchers',
 'said',
 '<eos>',
 '<unk>',
 'inc.',
 'the',
 'unit',
 'of',
 'new',
 'york-based',
 '<unk>',
 'corp.',
 'that',
 'makes',
 'kent',
 'cigarettes',
 'stopped',
 'using',
 '<unk>',
 'in',
 'its',
 '<unk>',
 'cigarette',
 'filters',
 'in',
 'N',
 '<eos>',
 'although',
 'preliminary',
 'findings',
 'were',
 'reported',
 'more',
 'than',
 'a',
 'year',
 'ago',
 'the',
 'latest',
 'results',
 'appear',
 'in',
 'today',
 "'s",
 'new',
 'england',
 'journal',
 'of',
 'medicine',
 'a',
 'forum',
 'likely',
 'to',
 'bring',
 'new',
 'attention',
 'to',
 'the',
 'problem',
 '<eos>',
 'a',
 '<unk>',
 '<unk>',
 'said',
 'this',
 'is',
 'an',
 'old',
 'story',
 '<eos>',
 'we',
 "'re",
 'talking',
 'about',
 'years',
 'ago',
 'before',
 'anyone',
 'heard',
 'of',
 'asbestos',
 'having',
 'any',
 'questionable',
 'properties',
 '<eos>',
 'there',
 'is',
 'no',
 'asbestos',
 'in',
 'our',
 'products',
 'now',
 '<eos>',
 'neither',
 '<unk>',
 'nor',
 'the',
 'researchers',
 'who',
 'studied',
 'the',
 'workers',
 'were',
 'aware',
 'of',
 'any',
 'research',
 'on',
 'smokers',
 'of',
 'the',
 'kent',
 'cigarettes',
 '<eos>',
 'we',
 'have',
 'no',
 'useful',
 'information',
 'on',
 'whether',
 'users',
 'are',
 'at',
 'risk',
 'said',
 'james',
 'a.',
 '<unk>',
 'of',
 'boston',
 "'s",
 '<unk>',
 'cancer',
 'institute',
 '<eos>',
 'dr.',
 '<unk>',
 'led',
 'a',
 'team',
 'of',
 'researchers',
 'from',
 'the',
 'national',
 'cancer',
 'institute',
 'and',
 'the',
 'medical',
 'schools',
 'of',
 'harvard',
 'university',
 'and',
 'boston',
 'university',
 '<eos>',
 'the',
 '<unk>',
 'spokeswoman',
 'said',
 'asbestos',
 'was',
 'used',
 'in',
 'very',
 'modest',
 'amounts',
 'in',
 'making',
 'paper',
 'for',
 'the',
 'filters',
 'in',
 'the',
 'early',
 '1950s',
 'and',
 'replaced',
 'with',
 'a',
 'different',
 'type',
 'of',
 '<unk>',
 'in',
 'N',
 '<eos>',
 'from',
 'N',
 'to',
 'N',
 'N',
 'billion',
 'kent',
 'cigarettes',
 'with',
 'the',
 'filters',
 'were',
 'sold',
 'the',
 'company',
 'said',
 '<eos>',
 'among',
 'N',
 'men',
 'who',
 'worked',
 'closely',
 'with',
 'the',
 'substance',
 'N',
 'have',
 'died',
 'more',
 'than',
 'three',
 'times',
 'the',
 'expected',
 'number',
 '<eos>',
 'four',
 'of',
 'the',
 'five',
 'surviving',
 'workers',
 'have',
 '<unk>',
 'diseases',
 'including',
 'three',
 'with',
 'recently',
 '<unk>',
 'cancer',
 '<eos>',
 'the',
 'total',
 'of',
 'N',
 'deaths',
 'from',
 'malignant',
 '<unk>',
 'lung',
 'cancer',
 'and',
 '<unk>',
 'was',
 'far',
 'higher',
 'than',
 'expected',
 'the',
 'researchers',
 'said',
 '<eos>',
 'the',
 '<unk>',
 'rate',
 'is',
 'a',
 'striking',
 'finding',
 'among',
 'those',
 'of',
 'us',
 'who',
 'study',
 '<unk>',
 'diseases',
 'said',
 'dr.',
 '<unk>',
 '<eos>',
 'the',
 'percentage',
 'of',
 'lung',
 'cancer',
 'deaths',
 'among',
 'the',
 'workers',
 'at',
 'the',
 'west',
 '<unk>',
 'mass.',
 'paper',
 'factory',
 'appears',
 'to',
 'be',
 'the',
 'highest',
 'for',
 'any',
 'asbestos',
 'workers',
 'studied',
 'in',
 'western',
 'industrialized',
 'countries',
 'he',
 'said',
 '<eos>',
 'the',
 'plant',
 'which',
 'is',
 'owned',
 'by',
 '<unk>',
 '&',
 '<unk>',
 'co.',
 'was',
 'under',
 'contract',
 'with',
 '<unk>',
 'to',
 'make',
 'the',
 'cigarette',
 'filters',
 '<eos>',
 'the',
 'finding',
 'probably',
 'will',
 'support',
 'those',
 'who',
 'argue',
 'that',
 'the',
 'u.s.',
 'should',
 'regulate',
 'the',
 'class',
 'of',
 'asbestos',
 'including',
 '<unk>',
 'more',
 '<unk>',
 'than',
 'the',
 'common',
 'kind',
 'of',
 'asbestos',
 '<unk>',
 'found',
 'in',
 'most',
 'schools',
 'and',
 'other',
 'buildings',
 'dr.',
 '<unk>',
 'said',
 '<eos>',
 'the',
 'u.s.',
 'is',
 'one',
 'of',
 'the',
 'few',
 'industrialized',
 'nations',
 'that',
 'does',
 "n't",
 'have',
 'a',
 'higher',
 'standard',
 'of',
 'regulation',
 'for',
 'the',
 'smooth',
 '<unk>',
 'fibers',
 'such',
 'as',
 '<unk>',
 'that',
 'are',
 'classified',
 'as',
 '<unk>',
 'according',
 'to',
 '<unk>',
 't.',
 '<unk>',
 'a',
 'professor',
 'of',
 '<unk>',
 'at',
 'the',
 'university',
 'of',
 'vermont',
 'college',
 'of',
 'medicine',
 '<eos>',
 'more',
 'common',
 '<unk>',
 'fibers',
 'are',
 '<unk>',
 'and',
 'are',
 'more',
 'easily',
 'rejected',
 'by',
 'the',
 'body',
 'dr.',
 '<unk>',
 'explained',
 '<eos>',
 'in',
 'july',
 'the',
 'environmental',
 'protection',
 'agency',
 'imposed',
 'a',
 'gradual',
 'ban',
 'on',
 'virtually',
 'all',
 'uses',
 'of',
 'asbestos',
 '<eos>',
 'by',
 'N',
 'almost',
 'all',
 'remaining',
 'uses',
 'of',
 '<unk>',
 'asbestos',
 'will',
 'be',
 'outlawed',
 '<eos>',
 'about',
 'N',
 'workers',
 'at',
 'a',
 'factory',
 'that',
 'made',
 'paper',
 'for',
 'the',
 'kent',
 'filters',
 'were',
 'exposed',
 'to',
 'asbestos',
 'in',
 'the',
 '1950s',
 '<eos>',
 'areas',
 'of',
 'the',
 'factory',
 'were',
 'particularly',
 'dusty',
 'where',
 'the',
 '<unk>',
 'was',
 'used',
 '<eos>',
 'workers',
 'dumped',
 'large',
 '<unk>',
 '<unk>',
 'of',
 'the',
 'imported',
 'material',
 'into',
 'a',
 'huge',
 '<unk>',
 'poured',
 'in',
 'cotton',
 'and',
 '<unk>',
 'fibers',
 'and',
 '<unk>',
 'mixed',
 'the',
 'dry',
 'fibers',
 'in',
 'a',
 'process',
 'used',
 'to',
 'make',
 'filters',
 '<eos>',
 'workers',
 'described',
 'clouds',
 'of',
 'blue',
 'dust',
 'that',
 'hung',
 'over',
 'parts',
 'of',
 'the',
 'factory',
 'even',
 'though',
 '<unk>',
 'fans',
 '<unk>',
 'the',
 'area',
 '<eos>',
 'there',
 "'s",
 'no',
 'question',
 'that',
 'some',
 'of',
 'those',
 'workers',
 'and',
 'managers',
 'contracted',
 '<unk>',
 'diseases',
 'said',
 '<unk>',
 'phillips',
 'vice',
 'president',
 'of',
 'human',
 'resources',
 'for',
 '<unk>',
 '&',
 '<unk>',
 '<eos>',
 'but',
 'you',
 'have',
 'to',
 'recognize',
 'that',
 'these',
 'events',
 'took',
 'place',
 'N',
 'years',
 'ago',
 '<eos>',
 'it',
 'has',
 'no',
 'bearing',
 'on',
 'our',
 'work',
 'force',
 'today',
 '<eos>',
 'yields',
 'on',
 'money-market',
 'mutual',
 'funds',
 'continued',
 'to',
 'slide',
 'amid',
 'signs',
 'that',
 'portfolio',
 'managers',
 'expect',
 'further',
 'declines',
 'in',
 'interest',
 'rates',
 '<eos>',
 'the',
 'average',
 'seven-day',
 'compound',
 'yield',
 'of',
 'the',
 'N',
 'taxable',
 'funds',
 'tracked',
 'by',
 '<unk>',
 "'s",
 'money',
 'fund',
 'report',
 'eased',
 'a',
 'fraction',
 'of',
 'a',
 'percentage',
 'point',
 'to',
 'N',
 'N',
 'from',
 'N',
 'N',
 'for',
 'the',
 'week',
 'ended',
 'tuesday',
 '<eos>',
 'compound',
 'yields',
 'assume',
 'reinvestment',
 'of',
 'dividends',
 'and',
 'that',
 'the',
 'current',
 'yield',
 'continues',
 'for',
 'a',
 'year',
 '<eos>',
 'average',
 'maturity',
 'of',
 'the',
 'funds',
 "'",
 'investments',
 '<unk>',
 'by',
 'a',
 'day',
 'to',
 'N',
 'days',
 'the',
 'longest',
 'since',
 'early',
 'august',
 'according',
 'to',
 'donoghue',
 "'s",
 '<eos>',
 'longer',
 'maturities',
 'are',
 'thought',
 'to',
 'indicate',
 'declining',
 'interest',
 'rates',
 'because',
 'they',
 'permit',
 'portfolio',
 'managers',
 'to',
 'retain',
 'relatively',
 'higher',
 'rates',
 'for',
 'a',
 'longer',
 'period',
 '<eos>',
 'shorter',
 'maturities',
 'are',
 'considered',
 'a',
 'sign',
 'of',
 'rising',
 'rates',
 'because',
 'portfolio',
 'managers',
 'can',
 'capture',
 'higher',
 'rates',
 'sooner',
 '<eos>',
 'the',
 'average',
 'maturity',
 'for',
 'funds',
 'open',
 'only',
 'to',
 'institutions',
 'considered',
 'by',
 'some',
 'to',
 'be',
 'a',
 'stronger',
 'indicator',
 'because',
 'those',
 'managers',
 'watch',
 'the',
 'market',
 'closely',
 'reached',
 'a',
 'high',
 'point',
 'for',
 'the',
 'year',
 'N',
 'days',
 '<eos>',
 'nevertheless',
 'said',
 '<unk>',
 '<unk>',
 '<unk>',
 'editor',
 'of',
 'money',
 'fund',
 'report',
 'yields',
 'may',
 '<unk>',
 'up',
 'again',
 'before',
 'they',
 '<unk>',
 'down',
 'because',
 'of',
 'recent',
 'rises',
 'in',
 'short-term',
 'interest',
 'rates',
 '<eos>',
 'the',
 'yield',
 'on',
 'six-month',
 'treasury',
 'bills',
 'sold',
 'at',
 'monday',
 "'s",
 'auction',
 'for',
 'example',
 'rose',
 'to',
 'N',
 'N',
 'from',
 'N',
 'N',
 '<eos>',
 'despite',
 'recent',
 'declines',
 'in',
 'yields',
 'investors',
 'continue',
 'to',
 'pour',
 'cash',
 'into',
 'money',
 'funds',
 '<eos>',
 'assets',
 'of',
 'the',
 'N',
 'taxable',
 'funds',
 'grew',
 'by',
 '$',
 'N',
 'billion',
 'during',
 'the',
 ...]
word_to_id = {}
id_to_word = {}

for i, word in enumerate(words):
    if word not in word_to_id:
        new_id = len(word_to_id)
        word_to_id[word] = new_id
        id_to_word[new_id] = word
print('corpus size :', len(words))        
print('vocab size : ', len(id_to_word))        
corpus size : 929589
vocab size :  10000
def load_vocab():
    words = open(file_path).read().replace('\n', '<eos>').strip().split()
    word_to_id = {}
    id_to_word = {}

    for i, word in enumerate(words):
        if word not in word_to_id:
            new_id = len(word_to_id)
            word_to_id[word] = new_id
            id_to_word[new_id] = word
    # print('corpus size :', len(words))        
    print('vocab size : ', len(id_to_word))       

    return word_to_id, id_to_word
word_to_id, id_to_word = load_vocab()
vocab size :  10000
dataset_dir = './ptb_dataset/'
train_file_name = 'ptb.train.txt'
window_size = 5
sample_size = 5
dataset = PTBDataset(file_path, window_size, sample_size)
vocab size: 10000
corpus size : 929589
contexts.shape: (929579, 10) target.shape: (929579,)
len(dataset)
929579
dataset[100] # contexts, target
(array([76, 77, 64, 78, 79, 27, 28, 81, 82, 83]), 80)
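
To read this sample, the ids can be decoded back into words using the vocabulary built inside the dataset (an illustrative check):

ctx_ids, tgt_id = dataset[100]
[dataset.id_to_word[int(i)] for i in ctx_ids], dataset.id_to_word[int(tgt_id)]
# (['workers', 'exposed', 'to', 'it', 'more', 'N', 'years', 'ago', 'researchers', 'reported'], 'than')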

3. Load the Data (DataLoader)

batch_size=100
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
contexts, target = next(iter(dataloader))
contexts.size(), target.size()
(torch.Size([100, 10]), torch.Size([100]))

4. Build the Model

class CBOW_NS_Model(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(CBOW_NS_Model, self).__init__()
        self.embedding_in = nn.Embedding(num_embeddings=vocab_size, embedding_dim = hidden_size)
        self.embedding_out = nn.Embedding(num_embeddings=vocab_size, embedding_dim = hidden_size)

    def forward(self, inputs, targets): 
        h = self.embedding_in(inputs)           # h : (batch_size, context_size, hidden_size)
        h = h.mean(axis=1)                      # h : (batch_size, hidden_size)
        target_W = self.embedding_out(targets)  # target_W : (batch_size, hidden_size)
        out = torch.sum(target_W * h, axis=1)   # per-example dot product : (batch_size,)
        prob = torch.sigmoid(out)               # torch.sigmoid (F.sigmoid is deprecated)
        return prob
vocab_size = 10000
hidden_size = 100

model = CBOW_NS_Model(vocab_size=vocab_size, hidden_size=hidden_size)
model.to(device)
CBOW_NS_Model(
  (embedding_in): Embedding(10000, 100)
  (embedding_out): Embedding(10000, 100)
)
contexts, target = contexts.to(device), target.to(device)
out = model(contexts, target)
out.shape, out.dtype
(torch.Size([100]), torch.float32)
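
The forward pass computes, per example, sigmoid(h · w_target), where h is the mean of the context word embeddings. A small sanity-check sketch that reproduces out by hand:

h = model.embedding_in(contexts).mean(dim=1)                          # (100, 100)
manual = torch.sigmoid((model.embedding_out(target) * h).sum(dim=1))  # (100,)
torch.allclose(manual, out)  # expected: True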

5. Configure the Model (Loss Function, Optimizer)

loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
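
Since forward() already applies the sigmoid, nn.BCELoss on probabilities is the matching choice here. A numerically more stable variant (an optional sketch, not what is trained below) would return the raw score from forward() and pair it with nn.BCEWithLogitsLoss:

# optional variant: have forward() return the raw score `out` instead of `prob`
logits_loss_fn = nn.BCEWithLogitsLoss()   # applies the sigmoid internally
# positive_loss = logits_loss_fn(raw_score, correct_label)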

6. Train the Model

def train_loop(model, dataloader, loss_fn, epochs, optimizer):  
    # early-stopping bookkeeping (declared here but not used in this loop)
    min_loss = 1000000
    trigger = 0
    patience = 4

    for epoch in range(epochs):
        model.train()
        train_loss = []

        for batch_data in dataloader:
            contexts = batch_data[0].to(device)
            target = batch_data[1].to(device)
            
            negative_sample = dataset.sampler.get_negative_sample(target)
            negative_sample = torch.LongTensor(negative_sample).to(device)
            
            optimizer.zero_grad()
            
            # forward pass for the positive sample
            positive_prob = model(contexts, target)
            correct_label = torch.ones(target.shape[0]).to(device)
            positive_loss = loss_fn(positive_prob, correct_label)

            # forward pass for the negative samples
            # negative_sample.shape : (batch_size, sample_size)
            negative_label = torch.zeros(target.shape[0]).to(device)
            negative_loss = 0
            for i in range(sample_size):
                negative_target = negative_sample[:, i]
                negative_prob = model(contexts, negative_target)
                negative_loss += loss_fn(negative_prob, negative_label)
                
            loss = positive_loss + negative_loss
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        train_epoch_loss = np.mean(train_loss)
        total_loss["train"].append(train_epoch_loss)

        print("Epoch: {}/{}, Train Loss={:.5f}".format(                    
               epoch + 1, epochs,
               total_loss["train"][-1]))                  
epochs = 12
total_loss = {"train": []}
%time train_loop(model, dataloader, loss_fn, epochs, optimizer)
Epoch: 1/12, Train Loss=4.46790
Epoch: 2/12, Train Loss=2.73813
Epoch: 3/12, Train Loss=2.32509
Epoch: 4/12, Train Loss=2.07252
Epoch: 5/12, Train Loss=1.90162
Epoch: 6/12, Train Loss=1.77771
Epoch: 7/12, Train Loss=1.68148
Epoch: 8/12, Train Loss=1.60607
Epoch: 9/12, Train Loss=1.54479
Epoch: 10/12, Train Loss=1.49309
Epoch: 11/12, Train Loss=1.44821
Epoch: 12/12, Train Loss=1.40948
CPU times: user 9min 1s, sys: 6.54 s, total: 9min 8s
Wall time: 9min 9s

7. Measure Similarity

# embedding from first model layer
# detach : https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html

embeddings = list(model.parameters())[0]
embeddings = embeddings.cpu().detach().numpy()
embeddings.shape
(10000, 100)
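
Equivalently, the same matrix can be read off the named layer, which is a bit more explicit than indexing model.parameters():

embeddings_alt = model.embedding_in.weight.detach().cpu().numpy()
np.allclose(embeddings, embeddings_alt)  # expected: True
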
def cos_similarity(x, y, eps=1e-8):
    '''Compute cosine similarity.

    :param x: vector
    :param y: vector
    :param eps: small value to prevent division by zero
    :return: cosine similarity between x and y
    '''
    nx = x / (np.sqrt(np.sum(x ** 2)) + eps)
    ny = y / (np.sqrt(np.sum(y ** 2)) + eps)
    return np.dot(nx, ny)
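
A quick check of the helper on two simple vectors (hypothetical values); a 45-degree angle gives roughly 0.707:

cos_similarity(np.array([1.0, 0.0]), np.array([1.0, 1.0]))  # ≈ 0.7071
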
def most_similar(query, word_to_id, id_to_word, word_matrix, top=5):
    # 1. Look up the query word
    if query not in word_to_id:
        print('%s was not found.' % query)
        return

    print('\n[query] ' + query)
    query_id = word_to_id[query]
    query_vec = word_matrix[query_id]

    # 2. Compute cosine similarity
    vocab_size = len(id_to_word)

    similarity = np.zeros(vocab_size)
    for i in range(vocab_size):
        similarity[i] = cos_similarity(word_matrix[i], query_vec)

    # 3. Print words in descending order of cosine similarity
    count = 0
    for i in (-1 * similarity).argsort():
        if id_to_word[i] == query:
            continue
        print(' %s: %s' % (id_to_word[i], similarity[i]))

        count += 1
        if count >= top:
            return
word_vecs = embeddings

# find the most similar words for each query
querys = ['you', 'year', 'car', 'toyota']
for query in querys:
    most_similar(query, word_to_id, id_to_word, word_vecs, top=5)
[query] you
 we: 0.7707014679908752
 i: 0.6373926401138306
 they: 0.5769265294075012
 triple-a: 0.43831726908683777
 she: 0.42899152636528015

[query] year
 month: 0.839719295501709
 week: 0.7326676249504089
 summer: 0.5758762359619141
 decade: 0.56444251537323
 spring: 0.5531997680664062

[query] car
 move: 0.5004615187644958
 buildings: 0.38829562067985535
 furor: 0.3863994777202606
 record: 0.36925071477890015
 plant: 0.3521578013896942

[query] toyota
 strengths: 0.43140339851379395
 marble: 0.38097456097602844
 sdi: 0.3780283033847809
 entertaining: 0.3772747218608856
 ford: 0.3746541738510132
