728x90
1. 모델 전체 저장하기
import tensorflow as tf
import numpy as np
(trainX, trainY), (testX, testY) = tf.keras.datasets.fashion_mnist.load_data()
trainX = trainX / 255.0
testX = testX / 255.0
trainX = trainX.reshape( (trainX.shape[0], 28,28,1) )
testX = testX.reshape( (testX.shape[0], 28,28,1) )
model = tf.keras.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])
model.fit(trainX, trainY, validation_data=(testX, testY), epochs=3)
medel.summary()
# 모델 전체 저장하기기
model.save('새폴더/model1')
# 저장된 모델 불러오기
from keras.optimizers import optimizer
불러온모델 = tf.keras.models.load_model('새폴더/model1')
불러온모델.summary()
# acc 그지 같이 나오면 아래코드 다시 실행
# 불러온모델.compile(loss='sparse_categorical_crossentropy', optimizer="adam", metrics='sparse_categorical_accuracy')
불러온모델.evaluate(testX, testY)
2. w값만 저장하기
import tensorflow as tf
import numpy as np
(trainX, trainY), (testX, testY) = tf.keras.datasets.fashion_mnist.load_data()
trainX = trainX / 255.0
testX = testX / 255.0
trainX = trainX.reshape( (trainX.shape[0], 28,28,1) )
testX = testX.reshape( (testX.shape[0], 28,28,1) )
model = tf.keras.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax'),
])
# ModelCheckpoint 로 모델의 w값만 저장하기
콜백함수 = tf.keras.callbacks.ModelCheckpoint(
filepath='체크포인트/mnist',
save_weights_only=True,
save_freq='epoch'
)
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])
model.fit(trainX, trainY, validation_data=(testX, testY), epochs=3, callbacks=[콜백함수]) # callback
medel.summary()
# w값만 저장해놨으면 모델 만들고 w값(checkpoint파일) 로드
model2 = tf.keras.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax'),
])
model2.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])
model2.fit(trainX, trainY, validation_data=(testX, testY), epochs=3, callbacks=[콜백함수]) # callback
model2.load_weights('체크포인트/mnist') # 아까 저장한 w값 로드
728x90
'#05.코딩애플 > +01.딥러닝' 카테고리의 다른 글
| [코딩애플] 개 vs 고양이 이미지 분류하기 (0) | 2023.04.22 |
|---|---|
| [코딩애플] 개 vs 고양이 데이터 분류하기 (0) | 2023.04.19 |
| [코딩애플] CNN (0) | 2023.04.17 |
| [코딩애플] 딥러닝이란? (0) | 2023.04.01 |