[Co Labor] Word2Vec을 사용한 AI 키워드 검색 구현

2024. 10. 26. 08:59프로젝트: Co Laobr

기능 설명 📘

  • 키워드를 AI를 활용하여 검색
  • 입력한 키워드와 유사한 정보 가지고있는 Review, Job, Enterprise JSON으로 리턴

구현 방법 🛠

  1. 데이터 추출 및 학습: MySQL 데이터베이스에서 데이터를 추출하고 Word2Vec 모델을 학습하여 유사 키워드를 생성하는 모델 생성
  2. 유사 키워드 검색: 사용자가 입력한 키워드를 바탕으로 학습된 모델을 사용하여 유사한 키워드를 생성
  3. 데이터베이스 검색: 유사 키워드를 사용하여 데이터베이스에서 관련된 데이터를 검색
  4. 검색 결과 반환: 검색된 데이터를 JSON 형식으로 반환

build.gradle

deeplearning4j-nlp, nd4j 의존성 추가

plugins {
	id 'java'
	id 'org.springframework.boot' version '3.3.1'
	id 'io.spring.dependency-management' version '1.1.5'
}

group = 'pelican'
version = '0.0.1-SNAPSHOT'

java {
	toolchain {
		languageVersion = JavaLanguageVersion.of(17)
	}
}

configurations {
	compileOnly {
		extendsFrom annotationProcessor
	}
}

repositories {
	mavenCentral()
}

dependencies {
	implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
	implementation 'org.springframework.boot:spring-boot-starter-thymeleaf'
	implementation 'org.springframework.boot:spring-boot-starter-web'
	implementation 'org.deeplearning4j:deeplearning4j-nlp:1.0.0-beta7'
	implementation 'org.nd4j:nd4j-native-platform:1.0.0-beta7'
	
	compileOnly 'org.projectlombok:lombok'
	developmentOnly 'org.springframework.boot:spring-boot-devtools'
	runtimeOnly 'com.mysql:mysql-connector-j'
	annotationProcessor 'org.springframework.boot:spring-boot-configuration-processor'
	annotationProcessor 'org.projectlombok:lombok'
	testImplementation 'org.springframework.boot:spring-boot-starter-test'
	testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
}

tasks.named('test') {
	useJUnitPlatform()
}

 

데이터 추출

데이터 추출 및 학습을 수행하는 DataPreparationService 객체의 fetchDataAndTrainModel 메서드는 다음과 같은 기능을 담당한다.

  1. Job의 Description, Title, Gender → List로 추출
  2. Review의 Pros, Cons, Title → List로 추출
  3. Enterprise의 Description, Name, Address → List로 추출
  4. 추출된 List 합치기
    @PostConstruct
    public void fetchDataAndTrainModel() {
        try {
            logger.info("Starting data extraction and model training...");

            // 데이터 추출
            logger.info("Extracting data from Job repository...");
            List<String> jobDescriptions = getSafeList(jobRepository.findAll().stream()
                    .map(Job::getDescription)
                    .collect(Collectors.toList()));

            List<String> jobTitles = getSafeList(jobRepository.findAll().stream()
                    .map(Job::getTitle)
                    .collect(Collectors.toList()));

            List<String> jobGenders = getSafeList(jobRepository.findAll().stream()
                    .map(Job::getGender)
                    .collect(Collectors.toList()));

            logger.info("Extracting data from Review repository...");
            List<String> reviewPros = getSafeList(reviewRepository.findAll().stream()
                    .map(Review::getPros)
                    .collect(Collectors.toList()));

            List<String> reviewCons = getSafeList(reviewRepository.findAll().stream()
                    .map(Review::getCons)
                    .collect(Collectors.toList()));

            List<String> reviewTitles = getSafeList(reviewRepository.findAll().stream()
                    .map(Review::getTitle)
                    .collect(Collectors.toList()));

            logger.info("Extracting data from Enterprise repository...");
            List<String> enterpriseDescriptions = getSafeList(enterpriseRepository.findAll().stream()
                    .map(Enterprise::getDescription)
                    .collect(Collectors.toList()));

            List<String> enterpriseNames = getSafeList(enterpriseRepository.findAll().stream()
                    .map(Enterprise::getName)
                    .collect(Collectors.toList()));

            List<String> enterpriseAddresses = getSafeList(enterpriseRepository.findAll().stream()
                    .map(Enterprise::getAddress)
                    .collect(Collectors.toList()));

            // 모든 데이터를 하나의 리스트로 합치기
            logger.info("Combining all extracted data into a single list...");
            List<String> sentences = Stream.of(jobDescriptions, jobTitles, jobGenders, reviewPros, reviewCons, reviewTitles, enterpriseDescriptions, enterpriseNames, enterpriseAddresses)
                    .flatMap(List::stream)
                    .collect(Collectors.toList());

            logger.info("Data extraction completed. Total number of sentences: {}", sentences.size());

            // 모델 학습
            logger.info("Starting Word2Vec model training...");
            trainWord2VecModel(sentences);
            logger.info("Word2Vec model training completed.");
        } catch (Exception e) {
            logger.error("An error occurred during data extraction and model training", e);
            throw new RuntimeException(e);
        }
    }

실제 모델 학습은 trainWord2VecModel 에서 수행하게 된다. 

Word2Vec 모델 학습

모델의 학습을 어떤 코드로 했는지 알기 전에 Word2Vec에 대해 알아야 한다.

Word2Vec은 단어를 고정된 크기의 벡터로 변환한다. 변환된 벡터를 이용해서 단어 간의 유사성을 측정하고, 벡터 공간에서 유사한 단어들은 가까운 거리에 위치하게 된다. 이때 코사인 유사도가 사용된다. 코사인 유사도는 두 벡터가 이루는 각도의 코사인을 측정하여 유사도를 계산하는 방법이다.

Word2Vec에서는 두 가지 주요 학습 알고리즘이 존재한다.

  • CBOW는 주어진 Context 단어들을 사용하여 중심 단어를 예측하는 방식이다.
  • Skip-gram은 주어진 중심 단어로 주변 단어를 예측하는 방식이다. 코드에서는 Skip-gram 방식을 사용하였다.
private void trainWord2VecModel(List<String> sentences) {
    try {
        logger.info("Initializing Word2Vec model...");

        // 문장을 반복자로 변환
        CollectionSentenceIterator sentenceIterator = new CollectionSentenceIterator(sentences);

        // 토크나이저 팩토리 설정
        DefaultTokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

        // Word2Vec 모델 빌더 설정
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(1)  // 단어의 최소 등장 빈도
                .iterations(5)        // 학습 반복 횟수
                .layerSize(100)       // 벡터의 차원 수
                .seed(42)             // 랜덤 시드
                .windowSize(5)        // 컨텍스트 윈도우 크기
                .iterate(sentenceIterator)  // 문장 반복자 설정
                .tokenizerFactory(tokenizerFactory)  // 토크나이저 팩토리 설정
                .build();

        logger.info("Fitting Word2Vec model...");

        // 모델 학습
        vec.fit();

        logger.info("Saving Word2Vec model to word2vecModel.txt...");

        // 모델 저장
        WordVectorSerializer.writeWord2VecModel(vec, "word2vecModel.txt");
        logger.info("Word2Vec model saved successfully.");
    } catch (Exception e) {
        logger.error("An error occurred during Word2Vec model training", e);
        throw new RuntimeException(e);
    }
}

텍스트 데이터를 단어 단위로 분리하기 위해 토크나이저 단계가 필요하다. 코드에서 토크나이저 단계는 다음과 같이 수행된다.

  1. CollectionSentenceIterator sentenceIterator = new CollectionSentenceIterator(sentences); : 문장을 순차적으로 읽기 위한 반복자를 설정한다.
  2. DefaultTokenizerFactory tokenizerFactory = new DefaultTokenizerFactory(); : 텍스트를 단어 단위로 분리할 수 있는 토크나이저를 생성한다.
  3. tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor()); : 모든 단어를 소문자로 변환하고 특수 문자를 제거하는 등 단어를 전처리하여 모델 학습의 효율성을 높인다.

예를 들어 설명하면 토크나이저는 다음과 같은 과정을 거치는 것이다.

  • 입력 문장: "The quick brown fox jumps over the lazy dog."
  • 전처리: "the quick brown fox jumps over the lazy dog" (소문자로 변환)
  • 토큰화: ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] (단어 단위로 분리)

다음은 Word2Vec 모델을 설정하는 부분이다.

Word2Vec vec = new Word2Vec.Builder()
        .minWordFrequency(1)  // 단어의 최소 등장 빈도
        .iterations(5)        // 학습 반복 횟수
        .layerSize(100)       // 벡터의 차원 수
        .seed(42)             // 랜덤 시드
        .windowSize(5)        // 컨텍스트 윈도우 크기
        .iterate(sentenceIterator)  // 문장 반복자 설정
        .tokenizerFactory(tokenizerFactory)  // 토크나이저 팩토리 설정
        .build();
  • minWordFrequency(1): 단어의 최소 등장 빈도를 1로 설정하여, 모든 단어를 벡터화한다. 만약 2로 설정한다면 2번 이상 등장한 단어만 대상으로 하게 될 것이다.
  • iterations(5): 학습 반복 횟수를 5로 설정한다.
  • layerSize(100): 단어 벡터의 차원 수를 100으로 설정한다. 즉, 각 단어를 100차원의 벡터로 표현하는 것이다. 벡터의 차원 수가 클수록 단어 간 관계를 잘 표현할 수 있다.
  • seed(42): 랜덤 시드를 42로 설정하여, 학습 결과의 재현성을 보장한다.
  • windowSize(5): 컨텍스트 윈도우 크기를 5로 설정하여 좌우 5개 단어를 컨텍스트로 사용한다.
  • iterate(sentenceIterator): 문장 반복자를 설정한다.
  • tokenizerFactory(tokenizerFactory): 토크나이저 팩토리를 설정한다.

설정된 모델은 vec.fit() 을 통해 설정된 파라미터로 모델을 학습시키고 WordVectorSerializer.writeWord2VecModel(vec, "word2vecModel.txt"); 를 통해 학습된 모델을 파일로 저장한다. 이 파일은 나중에 로드하여 사용할 수 있다.

모델 사용

package pelican.co_labor.service;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;

import java.io.File;
import java.util.Collections;
import java.util.List;

@Service
public class KeywordSearchService {

    private static final Logger logger = LoggerFactory.getLogger(KeywordSearchService.class);
    private Word2Vec word2Vec;

    public KeywordSearchService() {
        try {
            loadModel();
        } catch (Exception e) {
            logger.error("An error occurred while loading the Word2Vec model", e);
        }
    }

    private void loadModel() {
        File modelFile = new File("word2vecModel.txt");
        if (modelFile.exists()) {
            logger.info("Loading Word2Vec model from file...");
            word2Vec = WordVectorSerializer.readWord2VecModel(modelFile);
            logger.info("Word2Vec model loaded successfully.");
        } else {
            logger.warn("Word2Vec model file not found. Please train the model first.");
        }
    }

    public List<String> searchSimilarWords(String keyword) {
        if (word2Vec == null) {
            logger.warn("Word2Vec model is not loaded. Please train the model first.");
            return Collections.emptyList();
        }

        if (!word2Vec.hasWord(keyword)) {
            logger.warn("Keyword '{}' not found in Word2Vec model vocabulary.", keyword);
            return Collections.emptyList();
        }

        return (List<String>) word2Vec.wordsNearest(keyword, 10);
    }
}

return (List<String>) word2Vec.wordsNearest(keyword, 10); : 주어진 키워드와 유사도가 가장 가까운 단어 10개를 리턴한다. 원래 wordsNearestSum 을 사용했는데, 해당 메서드는 여러 단어의 벡터 합계를 사용하기에 단일 키워드를 사용했을 때 예상치 못한 결과가 나와서 변경하였다.

AI 검색 컨트롤러

package pelican.co_labor.controller;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import pelican.co_labor.domain.enterprise.Enterprise;
import pelican.co_labor.domain.job.Job;
import pelican.co_labor.domain.review.Review;
import pelican.co_labor.service.KeywordSearchService;
import pelican.co_labor.service.SearchService;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

@RestController
public class AISearchController {

    private static final Logger logger = LoggerFactory.getLogger(AISearchController.class);

    @Autowired
    private KeywordSearchService keywordSearchService;

    @Autowired
    private SearchService searchService;

    @GetMapping("/ai-search")
    public Map<String, Object> search(@RequestParam String keyword) {
        Map<String, Object> response = new HashMap<>();

        List<String> similarKeywords = keywordSearchService.searchSimilarWords(keyword);
        logger.info("Similar keywords: {}", similarKeywords);

        if (similarKeywords.isEmpty()) {
            response.put("message", "No similar words found for the given keyword in the database.");
            return response;
        }

        // DB 검색
        List<Job> jobs = searchService.searchJobs(similarKeywords);
        List<Review> reviews = searchService.searchReviews(similarKeywords);
        List<Enterprise> enterprises = searchService.searchEnterprises(similarKeywords);

        response.put("jobs", jobs);
        response.put("reviews", reviews);
        response.put("enterprises", enterprises);

        return response;
    }
}

유사한 단어 리스트를 similarKeywords 로 반환하고 List를 검색하는 메서드를 searchService 에 추가하였다.

package pelican.co_labor.service;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import pelican.co_labor.domain.enterprise.Enterprise;
import pelican.co_labor.domain.job.Job;
import pelican.co_labor.domain.review.Review;
import pelican.co_labor.repository.enterprise.EnterpriseRepository;
import pelican.co_labor.repository.job.JobRepository;
import pelican.co_labor.repository.review.ReviewRepository;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

@Service
public class SearchService {

    @Autowired
    private JobRepository jobRepository;

    @Autowired
    private ReviewRepository reviewRepository;

    @Autowired
    private EnterpriseRepository enterpriseRepository;

    public List<Job> searchJobs(String keyword) {
        return jobRepository.searchJobs(keyword);
    }

    public List<Review> searchReviews(String keyword) {
        return reviewRepository.searchReviews(keyword);
    }

    public List<Enterprise> searchEnterprises(String keyword) {
        return enterpriseRepository.searchEnterprises(keyword);
    }

    public List<Job> searchJobs(List<String> keywords) {
        Set<Job> jobs = new HashSet<>();
        for (String keyword : keywords) {
            jobs.addAll(jobRepository.searchJobs(keyword));
        }
        return new ArrayList<>(jobs);
    }

    public List<Review> searchReviews(List<String> keywords) {
        Set<Review> reviews = new HashSet<>();
        for (String keyword : keywords) {
            reviews.addAll(reviewRepository.searchReviews(keyword));
        }
        return new ArrayList<>(reviews);
    }

    public List<Enterprise> searchEnterprises(List<String> keywords) {
        Set<Enterprise> enterprises = new HashSet<>();
        for (String keyword : keywords) {
            enterprises.addAll(enterpriseRepository.searchEnterprises(keyword));
        }
        return new ArrayList<>(enterprises);
    }



}

 

이렇게 만든 AI 검색은 단어의 유사도를 사용했지만 성능은 잘 나오지 않았다. 정확한 이유는 알 수 없지만 일단 학습 데이터가 너무 적고 학습이 잘 되지 않은 느낌이었다. 그래서 어떤 방법을 썼느냐?는 다음 글을 보자!