-
Notifications
You must be signed in to change notification settings - Fork 74
/
Copy pathconfig.py
54 lines (42 loc) · 1.47 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from models import *
# Names of the GAN variants offered by this project, one entry per model.
model_zoo = [
    'DCGAN',
    'LSGAN',
    'WGAN',
    'WGAN-GP',
    'EBGAN',
    'BEGAN',
    'DRAGAN',
    'CoulombGAN',
]
def get_model(mtype, name, training):
    """Instantiate the GAN model selected by ``mtype``.

    The lookup is case-insensitive, so the upper-cased spelling
    ('COULOMBGAN') and the spelling listed in ``model_zoo`` ('CoulombGAN')
    both select the same class.

    Args:
        mtype: Model type name, e.g. 'DCGAN' or 'WGAN-GP'.
        name: Forwarded to the model constructor as ``name``.
        training: Forwarded to the model constructor as ``training``.

    Returns:
        The object produced by calling the selected model class with
        ``name=name, training=training``.

    Raises:
        AssertionError: If ``mtype`` does not name a known model, or the
            selected entry is a falsy placeholder (work in progress).
    """
    # Normalise once so callers may pass any casing; str() keeps a
    # non-string mtype from failing before the "not in the model zoo" error.
    key = str(mtype).upper()

    # The chain is kept lazy on purpose: only the module of the requested
    # model is touched, so an unrelated missing module cannot break lookup.
    if key == 'DCGAN':
        model = dcgan.DCGAN
    elif key == 'LSGAN':
        model = lsgan.LSGAN
    elif key == 'WGAN':
        model = wgan.WGAN
    elif key == 'WGAN-GP':
        model = wgan_gp.WGAN_GP
    elif key == 'EBGAN':
        model = ebgan.EBGAN
    elif key == 'BEGAN':
        model = began.BEGAN
    elif key == 'DRAGAN':
        model = dragan.DRAGAN
    elif key == 'COULOMBGAN':
        model = coulombgan.CoulombGAN
    else:
        # Raised explicitly (instead of via an assert statement) so the check
        # survives `python -O`; AssertionError is kept for compatibility.
        raise AssertionError('{} is not in the model zoo'.format(mtype))

    if not model:
        raise AssertionError('{} is work in progress'.format(mtype))

    return model(name=name, training=training)
def get_dataset(dataset_name):
    """Return the TFRecord glob pattern and example count for a dataset.

    Args:
        dataset_name: Either 'celeba' or 'lsun' (exact match).

    Returns:
        A ``(path, n_examples)`` tuple, where ``path`` is a glob pattern for
        the dataset's ``.tfrecord`` files and ``n_examples`` is the
        hard-coded number of examples for that dataset.

    Raises:
        ValueError: If ``dataset_name`` is neither 'celeba' nor 'lsun'.
    """
    # NOTE(review): './data/celebA_tfrecords/*.tfrecord' (presumably the
    # 64px CelebA records) also exists in the project layout but is not
    # selectable through this function.
    datasets = {
        'celeba': ('./data/celebA_128_tfrecords/*.tfrecord', 202599),
        'lsun': ('./data/lsun/bedroom_128_tfrecords/*.tfrecord', 3033042),
    }

    try:
        entry = datasets[dataset_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which must still be reported as
        # an unsupported dataset rather than leak a lookup error.
        entry = None

    if entry is None:
        raise ValueError(
            '{} is not supported. dataset must be celeba or lsun.'.format(dataset_name))

    return entry
def pprint_args(FLAGS):
    """Print the attributes of ``FLAGS`` as a ``NAME=value`` listing.

    Attribute names are upper-cased and the entries are emitted in sorted
    attribute order, framed by a "Parameters:" header and a trailing blank
    line.

    Args:
        FLAGS: Any object accepted by ``vars()``, e.g. a parsed-arguments
            namespace.
    """
    print("\nParameters:")
    for flag_name, flag_value in sorted(vars(FLAGS).items()):
        entry = "{}={}".format(flag_name.upper(), flag_value)
        print(entry)
    print("")