본문 바로가기
PROGRAMING📚/BigData📑

인공지능 - torch

Ta이니 2024. 9. 29.
728x90
반응형

머신러닝

  1. 케라스 : 학습, 연구용
  2. 텐서플로 : 기업용 서비스
  3. 파이토치 : 기계 학습용 무료

C:/workspace/anaconda3/condabin/conda.bat

conda.bat 경로 찾아서 Load Environment 해줌

 

다음과 같이 뜨면 + 버튼을 눌러서 pytorch package 를 설치해줌

install package 해주고 패키지가 설치가 될 때 까지 기다림

왼쪽 아래에서도 파이썬 패키지 설치 가능


간단한 인공 신경망, 순방향 계산만 처리

입력 뉴런 : 5개 , 출력 뉴런: 3개, 15개의 시냅스를 가지고 있는 인공 신경망을 생성

import torch
from PIL.Image import Transform
# 간단한 인공 신경망, 순방향 계산만 처리
# 입력 뉴런: 5개, 출력 뉴런: 3개, 15개의 시냅스를 가지고 있는 인공 신경망을 생성

# nn.Linear는 가중치(weight)와 편향(bias)을 사용하여 입력 데이터에 선형 변환을
# 적용하는 신경망에서 사용되는 선형 레이어.
hello = torch.nn.Linear(5,3)
# randn 표준정규분포를 만들어주는 임의의 수로 2행5열의 2차원 배열 생성
data = torch.randn(2,5)  # 임의의 입력 벡터 2개 만듦

print(data)
print(hello(data))

 

MNIST(Mixed National Institute of Standards and Technology) 데이터셋은 손으로 쓴 숫자(0~9)의 이미지로 구성된 데이터셋으로, 인공지능과 머신러닝 모델을 훈련시키는 데 자주 사용됩니다. 각 이미지는 28x28 픽셀 크기의 흑백 이미지로, 총 7만 개의 숫자 이미지(훈련용 6만 개, 테스트용 1만 개)로 이루어져 있습니다.

이 데이터셋은 주로 이미지 분류 및 패턴 인식, 특히 딥러닝 모델을 훈련시키는 기초적인 예제로 많이 사용되며, 간단한 신경망부터 복잡한 CNN(Convolutional Neural Network)까지 다양한 모델의 성능을 실험하는 데 사용됩니다.

torchvision 설치하기

torchvision은 PyTorch에서 컴퓨터 비전 작업을 쉽게 할 수 있도록 지원하는 라이브러리입니다. 주로 이미지 처리와 관련된 다양한 기능을 제공합니다. 주요 기능들은 다음과 같습니다:

  1. 데이터셋: 다양한 유명 데이터셋(CIFAR, MNIST, ImageNet 등)을 손쉽게 불러올 수 있게 해주는 API를 제공합니다. 데이터를 자동으로 다운로드하고 전처리해주는 기능도 포함되어 있어 학습에 필요한 데이터를 쉽게 준비할 수 있습니다.
  2. 변환(Transforms): 이미지 데이터를 전처리하기 위한 다양한 변환 기능을 제공합니다. 이미지 크기 조정, 자르기, 뒤집기, 정규화 등 데이터 증강 및 전처리를 위한 함수들이 포함되어 있습니다.
  3. 모델: 사전 훈련된 모델을 제공해 컴퓨터 비전 관련 작업을 빠르게 수행할 수 있습니다. 예를 들어, ResNet, VGG, MobileNet 등의 네트워크 구조가 포함되어 있으며, 이들을 바로 사용하거나 커스터마이즈할 수 있습니다.
  4. 유틸리티 함수: 이미지 입출력, 시각화 등 컴퓨터 비전 작업에 유용한 다양한 유틸리티 기능도 포함되어 있습니다.

이 라이브러리는 PyTorch를 사용한 이미지 인식, 분류, 객체 검출 등의 작업을 간편하게 수행할 수 있도록 도와줍니다.

torchsummary 설치하기

 

torchsummary는 PyTorch 모델의 구조와 파라미터 수를 요약해서 보여주는 패키지입니다. 모델의 각 계층(layer)이 어떻게 구성되어 있는지, 각 레이어의 출력 크기와 파라미터 개수를 한눈에 확인할 수 있도록 도와줍니다. Keras의 model.summary() 함수와 비슷한 기능을 제공하며, 주로 딥러닝 모델을 디버깅하거나 모델 아키텍처를 이해할 때 유용합니다.

주요 기능:

  1. 모델 요약: 주어진 모델에 대한 계층별 요약 정보를 출력합니다.
    • 각 계층의 이름과 유형
    • 출력 크기 (Output Shape)
    • 계층별 파라미터 수
  2. 전체 파라미터 수: 모델의 학습 가능한 파라미터와 전체 파라미터 수를 계산해 보여줍니다.
  3. 메모리 사용량: 각 레이어에서 소비하는 메모리(텐서 크기)를 파악할 수 있어, 모델이 GPU 메모리에 적합한지 여부를 빠르게 확인할 수 있습니다.

사용법 예시:

from torchsummary import summary
import torch
import torch.nn as nn

# 간단한 CNN 모델 정의
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 모델 생성
model = SimpleCNN()

# 모델 요약 (입력 크기는 (채널, 높이, 너비) 순서로 입력)
summary(model, (1, 28, 28))

이 코드를 실행하면 모델의 각 레이어, 출력 크기, 그리고 파라미터 수를 출력해 줍니다. torchsummary는 복잡한 모델의 아키텍처를 쉽게 이해하고 분석할 수 있도록 도와줍니다.

from torchvision.datasets import MNIST
import  torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
from torchsummary import summary

# 입력 데이터의 변환 방식 지정
# Mnist 데이터는 numpy형식으로 되어 있기 때문에 pytoch tensor로 전환하기 위함
rules = transforms.Compose([transforms.ToTensor()])

 

from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import  torch.nn as nn
from torchsummary import summary

print("훈련용 데이터")
# 훈련용 데이터
train_loader =DataLoader(MNIST('mnist',train=True, download=True, transform=rules),
batch_size =500, shuffle=True)
print(train_loader)

훈련용 데이터
<torch.utils.data.dataloader.DataLoader object at 0x0000025EEC28FAD0>

print("평가용 데이터")
#평가용 데이터
test_loader = DataLoader(MNIST('mnist', train =False, download=True, transform=rules),
                         batch_size= 500, shuffle=False)
images, labels = next(iter(train_loader))
print(images[0])
print(labels[0])

평가용 데이터
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 1.0000, 0.7490, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 1.0000, 1.0000, 0.5020, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.7490, 1.0000, 0.2510, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.2510, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.5020, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.7490, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.7490, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.7490, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.2510, 0.5020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 1.0000, 0.5020, 0.0000, 0.0000, 0.0000, 0.0000,
          0.5020, 1.0000, 1.0000, 1.0000, 1.0000, 0.7490, 0.2510, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.2510,
          1.0000, 1.0000, 0.7490, 1.0000, 0.7490, 0.5020, 0.5020, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.5020, 1.0000, 0.7490, 0.0000, 0.0000, 0.0000, 1.0000,
          0.5020, 0.7490, 0.0000, 0.0000, 0.0000, 0.0000, 0.2510, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 1.0000, 1.0000, 0.2510, 0.0000, 0.0000, 0.5020, 1.0000,
          1.0000, 0.2510, 0.0000, 0.0000, 0.0000, 0.0000, 0.5020, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.7490, 1.0000,
          0.2510, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2510, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.5020, 1.0000,
          0.7490, 0.0000, 0.0000, 0.0000, 0.0000, 0.7490, 0.2510, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 1.0000, 1.0000, 0.5020, 0.0000, 0.0000, 0.0000, 1.0000,
          1.0000, 0.7490, 0.5020, 0.5020, 0.7490, 1.0000, 0.2510, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.5020, 1.0000, 1.0000, 1.0000, 0.7490, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 0.7490, 1.0000, 0.5020, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.2510, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 0.7490, 0.5020, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.2510, 0.5020, 1.0000, 1.0000, 0.7490, 0.7490,
          0.7490, 0.7490, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000]]])
tensor(6)

Sequential

Sequential은 딥러닝 라이브러리인 Keras에서 사용되는 모델 구성 방식 중 하나로, 레이어들을 순차적으로 쌓는 간단한 모델을 만드는 데 사용됩니다. 이름에서 알 수 있듯이, 각 레이어는 순차적으로 진행되며, 하나의 레이어 출력이 다음 레이어의 입력이 됩니다.

주요 특징:

  1. 순차적 모델 구조: 레이어들이 차례대로 쌓이므로 입력에서 출력까지의 흐름이 직관적이고 간단합니다.
  2. 고정된 흐름: Sequential 모델은 레이어가 고정된 순서로만 진행됩니다. 따라서 복잡한 연결(병렬 구조, 여러 입력 또는 출력 등)은 지원하지 않습니다.
  3. 주로 기본적인 네트워크에 사용: 다중 퍼셉트론, CNN, RNN 등 일반적인 딥러닝 모델을 만들 때 주로 사용됩니다.

예시 코드:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Sequential 모델 생성
model = Sequential()

# 입력층과 은닉층 추가
model.add(Dense(64, input_dim=100, activation='relu'))  # input_dim은 입력의 크기
model.add(Dense(32, activation='relu'))  # 두 번째 은닉층
model.add(Dense(1, activation='sigmoid'))  # 출력층 (이진 분류를 위해 sigmoid 사용)

# 모델 컴파일
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 모델 요약 출력
model.summary()

주요 메서드:

  • add(): 새로운 레이어를 추가합니다.
  • compile(): 모델을 학습할 준비를 합니다. 옵티마이저, 손실 함수, 평가 지표 등을 설정합니다.
  • fit(): 모델을 학습시키는 메서드로, 학습 데이터와 반복 횟수(epoch)를 지정합니다.
  • evaluate(): 학습된 모델을 테스트 데이터로 평가합니다.
  • predict(): 학습된 모델로 새로운 데이터를 예측합니다.

언제 사용하나요?

  • 단순한 레이어 구조: 입력에서 출력까지 직선적인 흐름을 갖는 단순한 모델을 만들 때 적합합니다.
  • 빠른 프로토타이핑: 복잡한 구조를 필요로 하지 않는 경우, 간단하게 모델을 구성하고 실험할 수 있습니다.

복잡한 모델(다중 입력/출력, 병렬 처리 등)을 만들고자 한다면, Functional APIModel subclassing을 고려하는 것이 좋습니다.

손글씨 인식 딥러닝 실습

import  torch.nn as nn

# 파이토치 DNN를 Sequential 모델로 구현
model = nn.Sequential(
  nn.Flatten(), # 2차원 데이터를 1차원으로 변형

  #뉴런이 784개를 넣으면 128개의 필터가 있다
  #입력은 784개이고, weight 가 784개, bias는 1개 , 출력이 128개라는 의미
  nn.Linear(784, 128),
  nn.ReLU(), #음수의 값을 0으로 변환시켜주는 필터역할
  nn.Dropout(p=0.2), #128개의 출력값 중에 임의로 20%를 0으로 만들기 위함
  nn.Linear(128,10), #128개의 출력값을 10개의 출력값으로 변환
  nn.Softmax(dim=1), #10개의 확율값이 있는 1번축을 따라가며 큰값을 가져와라
)

print('{0:=^50}'.format("DNN Summary"))
summary(model, (1, 28,28))

nn.Flatten()

Flatten은 Keras나 딥러닝에서 사용되는 레이어로, 다차원 배열을 1차원 배열로 변환하는 역할을 합니다. 주로 CNN(Convolutional Neural Network)에서 합성곱 및 풀링 레이어를 거친 출력 데이터를 완전 연결(Dense) 층에 전달하기 전에 사용됩니다.

nn.Linear(784, 128)

nn.Linear(784, 128)은 PyTorch에서 784차원의 입력을 128차원의 출력으로 변환하는 선형(fully connected) 레이어를 정의하는 코드입니다.

동작 원리:

Linear 레이어는 다음 수식을 수행합니다:

y=Wx+by = Wx + b

y=Wx+b

  • x: 입력 벡터
  • W: 가중치 행렬
  • b: 편향(바이어스) 벡터
  • y: 출력 값

nn.ReLU()는 PyTorch에서 사용하는 Rectified Linear Unit (ReLU) 활성화 함수입니다. ReLU는 비선형 활성화 함수로, 입력이 0보다 작으면 0을 출력하고, 0보다 크거나 같으면 입력 값을 그대로 반환합니다.

ReLU 함수의 수식:

ReLU(x)=max⁡(0,x)\text{ReLU}(x) = \max(0, x)

ReLU(x)=max(0,x)

  • 입력 x: 네트워크에서 전달된 입력 값
  • xx
  • 출력: x가 0보다 크면 그대로 출력, 0 이하이면 0을 출력
  • xx

nn.Dropout(p=0.2)는 PyTorch에서 사용하는 드롭아웃(Dropout) 레이어로, 신경망의 과적합(overfitting)을 방지하기 위한 정규화 기법입니다.

nn.Softmax(dim=1)은 PyTorch에서 사용하는 소프트맥스(Softmax) 활성화 함수로, 주로 다중 클래스 분류 문제에서 사용됩니다. 이 함수는 입력 텐서의 각 원소를 확률로 변환하여, 모든 원소의 합이 1이 되도록 합니다.

주요 특징:

  • 확률 변환: 소프트맥스 함수는 각 클래스에 대한 로짓(logit)을 입력으로 받아서, 해당 클래스가 선택될 확률로 변환합니다.
  • 차원 지정: dim=1은 소프트맥스 함수가 각 행(row)에 대해 적용된다는 것을 의미합니다. 즉, 각 샘플의 클래스 확률을 계산합니다.

# 결과 값들이 0에 가까워 지면 0.01에 해당 되는 값을 좌우의 값들의 일부를 삭제 한다.
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)는 PyTorch에서 Adam 옵티마이저를 생성하는 코드입니다. 이 옵티마이저는 신경망의 가중치를 업데이트하는 데 사용됩니다. 각 구성 요소를 자세히 살펴보면 다음과 같습니다:

구성 요소 설명:

  1. torch.optim.Adam:
    • Adam(Adaptive Moment Estimation)은 모멘텀적응형 학습률을 결합한 옵티마이저입니다.
    • Adam은 학습 과정에서 각 매개변수에 대한 학습률을 동적으로 조정하여 효과적인 최적화를 제공합니다.
  2. model.parameters():
    • 현재 모델의 모든 파라미터(가중치 및 편향)를 반환합니다. 이 파라미터들은 학습 과정에서 업데이트될 대상입니다.
    • 즉, Adam 옵티마이저는 이 파라미터들을 조정하여 모델의 성능을 향상시키기 위해 사용됩니다.
  3. lr=0.01:
    • lrlearning rate(학습률)를 나타냅니다. 학습률은 가중치를 업데이트할 때의 스텝 크기를 결정합니다.
    • 0.01은 초기 학습률을 의미하며, 이 값은 문제에 따라 조정될 수 있습니다. 너무 크면 발산할 수 있고, 너무 작으면 학습 속도가 느려질 수 있습니다.

전체 동작 과정:

이 코드는 Adam 옵티마이저를 초기화하고, 지정된 학습률로 모델의 가중치를 업데이트할 준비를 합니다. 이후 학습 루프에서 optimizer.step()을 호출하여 가중치를 업데이트할 수 있습니다.

import torch
import torch.nn as nn

# 모델 정의
model = nn.Linear(10, 2)  # 간단한 선형 모델

# 옵티마이저 생성
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 손실 함수 정의
criterion = nn.CrossEntropyLoss()

# 가상의 입력과 타겟 생성
input_data = torch.randn(5, 10)  # 5개 샘플
target = torch.tensor([0, 1, 0, 1, 0])  # 타겟 레이블

# 학습 루프
for epoch in range(100):
    optimizer.zero_grad()  # 기울기 초기화
    output = model(input_data)  # 예측
    loss = criterion(output, target)  # 손실 계산
    loss.backward()  # 기울기 계산
    optimizer.step()  # 가중치 업데이트
# 훈련용 데이터, shuffle 하면 기계학습의 정확도를 높일 수 있음, batch :: 변환할 갯수
train_loader = DataLoader(
  MNIST('mnist', train=True, download=True, transform=rules),
  batch_size=500, shuffle=True
)

# 최적화 함수
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 손실함수
criterion = nn.CrossEntropyLoss()

# 인공신경망 학습
for epoch in range(5):
  for data in train_loader:
    inputs, labels = data # inputs(입력값), labels(결과값)
    outputs = model(inputs)

    loss = criterion(outputs, labels) # 손실값 계산

    # 역전파 알고리즘으로 가중치 보정
    model.zero_grad() # 기울기를 구하기 전 0으로 초기화
    loss.backward() # 역전파 Adam 사용해서 출력층에서 입력층으로 반대로 이동하며 기울기 보정
    optimizer.step() # 관련된 기울기의 보정치를 구하는 역활을 함.

  print('Epoch : {},'.format(epoch),
        'Loss : {:.3f}'.format(loss.item()))

# 인공신경망 평가
correct = 0
for images, labels in test_loader:
  with torch.no_grad():
    pred = model(images)

  # softmax 활성화, dim=1 열 운행을 하면서 1차원 벡터로 변경
  pred = torch.argmax(pred, dim=1)

  for i in range(500):
    if(pred[i] == labels[i]):
      correct += 1

print('정확도: ', correct/10000)

728x90
반응형

댓글