728x90
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from keras.engine.training import optimizer
# mnist 데이터셋 불러오기
tf.keras.datasets.mnist.load_data()
# train, test 데이터 나누기
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# 데이터값 0~1 사이로 바꾸기
x_train = x_train / 255.
x_test = x_test / 255.
# 1. Sequential 모델
medel = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=x_train.shape[1:]),
tf.keras.layers.Dense(256, activation="relu"),
tf.keras.layers.Dense(16, activation="relu"),
tf.keras.layers.Dense(10, activation="softmax"),
])
# 2. Functional 모델
inputs = Input(shape=x_train.shape[1:], name="inputlayer")
hidden1 = Flatten(input_shape=x_train.shape[1:])(inputs)
hidden2 = Dense(256, activation='relu')(hidden1)
hidden3 = Dense(16, activation='sigmoid')(hidden2)
output = Dense(10, activation="softmax", name="Outputlayer")(hidden3)
# 모델 저장
model = Model(inputs=inputs, outputs=output)
# 모델 요약 (일반)
model.summary()
# 모델 요약 (이미지)
tf.keras.utils.plot_model(model)
# 모델 compile(편집)
model.compile(
optimizer="adam",
loss="sparse_categorical_crossentropy", # sparse_ => 원핫인코딩해줌
# loss="categorical_crossentropy",
metrics=["accuracy"]
)
# callbacks.EarlyStopping => 더 이상 성능 나아지지 않는다 싶으면 훈련 끝냄
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
# monitor="val_accuracy", mode="max",
# monitor="val_loss", mode="min"
monitor="val_accuracy",
mode = "auto",
patience=10,
restore_best_weights=True,
verbose=1
)
# 모델 fit
history = model.fit(
x_train,
y_train,
validation_split = .25, # 검증용데이터 비율
callbacks=[early_stopping_callback], # 콜백
batch_size=1024, # 배치사이즈만큼씩 돌리겠다
epochs=100, # 전체 데이터를 에포크회 학습하겠다
verbose=2
)
# 모델 평가
loss, acc = model.evaluate(x_test, y_test) # verbose=0 or 1 or 2
loss, acc
history.history
# history => 매 epoch 마다 저장되어있는 학습이력
# loss : 훈련 손실값
# acc : 훈련 정확도
# val_loss : 검증 손실값
# val_acc : 검증 정확도
# history 시각화
length = len(history.history["accuracy"])
plt.plot(range(length), history.history["accuracy"])
plt.plot(range(length), history.history["val_accuracy"])
plt.show()
728x90
'#02.천재교육 빅데이터 > +14.딥러닝_심화' 카테고리의 다른 글
| 230524 API Server (딥러닝 모델 만들고 serving) (0) | 2023.05.24 |
|---|---|
| 230523 CNN, Lightening (0) | 2023.05.23 |