챗봇 코드가 필요한 사람은 이곳으로 가십시오, 이 포스트는 이 페이지 코드의 도움을 받았으며 이곳에 더 잘 정리되어 있습니다.
작년에 했던, 시짓는 AI프로젝트가 조금 찜찜하게 끝나고. (관련링크) 아쉬웠던 차에 올해도 AI관련 프로젝트를 할 기회가 생겨서 냉큼 참여하게 되었다.
이번에는 챗봇을 만들어보는 일인데.. 확실히 특정 기기만 사용해야 한다는 제약상황이 사라지니까 한결 살 것 같다. (작년에는 JetsonNano만 사용해야 했어서 tokenizer, torch..등과 os, python버전을 맞추느라 무진장 애먹었다..)
사용 환경은 다음과 같다.
Python version : Python 3.8.1
transformer : '4.19.2'
torch : '1.11.0+cpu'
transformer pretrainedModel : skt/kogpt-base-v2
chatbot baseData : haven-jeon/KoGPT2-charbot
model = GPT2LMHeadModel.from_pretrained('skt/kogpt2-base-v2')
으로 skt에서 공개한(감사합니다 ㅠ) kogpt2모델을 불러오고
print("start")
//
print("end")
사이의 코드를 통해 모델을 학습시킨다. (이 때 tokenizer가 필요하며, 역시 skt/kogpt2-base-v2모델을 사용한다)
학습이 끝난 모델은 torch.save(모델, 경로)를 통해서 학습시키며
torch.save(model, 'models/' + str(nowDate) + '_model.bin')
torch.load('modelpath')
나중에 torch.load(path)로 가져와 활용할 수 있다.
Torch Code는 이곳의 도움을 받았다.(감사합니다 ㅠ) https://wikidocs.net/157001 KoGPT2 챗봇 만들기
from regex import D
import setModelData
import numpy as np
import pandas as pd
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from torch.utils.data import DataLoader, Dataset
from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel
import re
Q_TKN = "<usr>"
A_TKN = "<sys>"
BOS = '</s>'
EOS = '</s>'
MASK = '<unused0>'
SENT = '<unused1>'
PAD = '<pad>'
koGPT2_TOKENIZER = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2",
bos_token=BOS, eos_token=EOS, unk_token='<unk>',
pad_token=PAD, mask_token=MASK)
model = GPT2LMHeadModel.from_pretrained('skt/kogpt2-base-v2')
import urllib.request
urllib.request.urlretrieve(
"https://raw.githubusercontent.com/songys/Chatbot_data/master/ChatbotData.csv",
filename="data/ChatBotData.csv",
)
Chatbot_Data = pd.read_csv("data/ChatBotData.csv")
# Test 용으로 300개 데이터만 처리한다.
Chatbot_Data = Chatbot_Data[:300]
Chatbot_Data.head()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_set = setModelData.ChatbotDataset(Chatbot_Data, max_len=40)
#윈도우 환경에서 num_workers 는 무조건 0으로 지정, 리눅스에서는 2
train_dataloader = DataLoader(train_set, batch_size=32, num_workers=0, shuffle=True, collate_fn=setModelData.collate_batch,)
model.to(device)
model.train()
learning_rate = 3e-5
criterion = torch.nn.CrossEntropyLoss(reduction="none")
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
import datetime as dt
now = dt.datetime.now()
nowDate = now.strftime('%m%d_%H%M')
# torch.save(model.state_dict(), 'models/' + str(nowDate) + '_dict.bin')
# torch.save(model, 'models/' + str(nowDate) + '_model.bin')
epoch = 10
Sneg = -1e18
print ("start")
for epoch in range(epoch):
for batch_idx, samples in enumerate(train_dataloader):
optimizer.zero_grad()
token_ids, mask, label = samples
out = model(token_ids)
out = out.logits #Returns a new tensor with the logit of the elements of input
mask_3d = mask.unsqueeze(dim=2).repeat_interleave(repeats=out.shape[2], dim=2)
mask_out = torch.where(mask_3d == 1, out, Sneg * torch.ones_like(out))
loss = criterion(mask_out.transpose(2, 1), label)
# 평균 loss 만들기 avg_loss[0] / avg_loss[1] <- loss 정규화
avg_loss = loss.sum() / mask.sum()
avg_loss.backward()
# 학습 끝
optimizer.step()
print ("end")
torch.save(model.state_dict(), 'models/' + str(nowDate) + '_dict.bin')
torch.save(model, 'models/' + str(nowDate) + '_model.bin')
sent = "0" # 0=일상, 1=부정, 2=긍정
with torch.no_grad():
while 1:
q = input("user > ").strip()
if q == "quit":
break
a = ""
while 1:
input_ids = torch.LongTensor(koGPT2_TOKENIZER.encode(Q_TKN + q + SENT + sent + A_TKN + a)).unsqueeze(dim=0)
pred = model(input_ids)
pred = pred.logits
gen = koGPT2_TOKENIZER.convert_ids_to_tokens(torch.argmax(pred, dim=-1).squeeze().numpy().tolist())[-1]
if gen == EOS:
break
a += gen.replace("▁", " ")
print("Chatbot > {}".format(a.strip()))%
'개발 조각글' 카테고리의 다른 글
Unity - 인스펙터 두개 띄워놓기 (0) | 2022.07.18 |
---|---|
Python - TextToCsv (0) | 2022.07.17 |
Processing - Video_Capture가 작동하지 않을 때 (0) | 2022.06.04 |
Phthon - Web 정보 긁어오기 (0) | 2022.05.30 |
Unity UI 컴포넌트 캐싱[0] (0) | 2022.05.20 |