1. 저장하기

import tensorflow.compat.v1 as tf
import numpy as np
import gym
from collections import deque
import random
import os

# Switch TensorFlow to TF1-style graph-mode behavior; the code below relies on
# tf.Session and tf.train.Saver from tensorflow.compat.v1.
tf.disable_v2_behavior()


def createFolder(directory):
    """Create *directory* (and any missing parent directories) if it is absent.

    This is best-effort: an ``OSError`` (e.g. permission denied, or a regular
    file already occupying the path) is reported on stdout instead of raised.

    Args:
        directory: Path of the directory to create.
    """
    try:
        # exist_ok=True turns "already exists" into a no-op inside makedirs
        # itself, closing the race between a separate exists() check and the
        # actual creation.
        os.makedirs(directory, exist_ok=True)
    except OSError:
        print('Error: Creating directory. ' + directory)


sess = tf.Session()

# NOTE(review): DQNet is assumed to add its tf.Variables to the default graph.
# It has to run before the Saver below, which snapshots tf.global_variables().
DQNmain = DQNet(sess, input_size, output_size, "DQNMain")
modelSaver = tf.train.Saver(
    var_list=tf.global_variables(),
    allow_empty=False,  # raise instead of building a Saver with nothing to save
)

# Give every global variable its initial value.
sess.run(tf.global_variables_initializer())

.....
.....

(여러 코드 구현 — 생략)

.....
.....


# Saver.save does not create the parent directory, so make sure it exists.
createFolder('my_test_model')
modelSaver.save(
    sess,
    'my_test_model/mymodel',
    # Weights only: no .meta graph file, so the graph must be rebuilt in code
    # before restoring.
    write_meta_graph=False,
    # The step number is appended to the checkpoint file name (mymodel-<step>);
    # global_step is expected to come from the omitted code above.
    global_step=global_step,
)

 

 

 

2. 불러다 쓰기 (write_meta_graph=False로 저장했으므로, 같은 네트워크를 코드로 다시 구성한 뒤 저장된 변수 값만 복원)

import sys

sess = tf.Session()

# Rebuild the network graph. The checkpoint holds only variable values
# (it was saved with write_meta_graph=False), so the graph itself has to be
# constructed again in code with the same name scope ("DQNMain").
DQNmain = DQNet(sess, input_size, output_size, "DQNMain")

# Saver over the same variable list that was used when saving.
modelSaver = tf.train.Saver(var_list=tf.global_variables(), allow_empty=False)

# Find the newest checkpoint in the directory. latest_checkpoint returns None
# when there is no checkpoint state file or its data files cannot be found,
# which replaces the deprecated tf.train.checkpoint_exists check.
ckpt_path = tf.train.latest_checkpoint('my_test_model/')
if ckpt_path:
    modelSaver.restore(sess, ckpt_path)
else:
    # A missing model is an error: report it on stderr and exit non-zero.
    # The site-provided builtin exit() returned status 0 here and is not
    # available when Python runs with -S.
    print("Model data not found...", file=sys.stderr)
    sys.exit(1)

 

+ Recent posts