1 Star 0 Fork 38

清泉之边/quiz-w7-2-densenet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
nets_factory.py 6.46 KB
一键复制 编辑 原始数据 按行查看 历史
dwSun 提交于 2017-12-13 17:27 . add densenet
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains a factory for building various models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
from nets import alexnet
from nets import cifarnet
from nets import inception
from nets import lenet
from nets import mobilenet_v1
from nets import overfeat
from nets import resnet_v1
from nets import resnet_v2
from nets import vgg
from nets import densenet
from nets.nasnet import nasnet
# Short alias for tf.contrib.slim; get_network_fn below uses slim.arg_scope.
slim = tf.contrib.slim
# Registry of model-building callables, keyed by the network name that
# get_network_fn accepts.
networks_map = {
    # AlexNet / CifarNet / OverFeat
    'alexnet_v2': alexnet.alexnet_v2,
    'cifarnet': cifarnet.cifarnet,
    'overfeat': overfeat.overfeat,
    # VGG
    'vgg_a': vgg.vgg_a,
    'vgg_16': vgg.vgg_16,
    'vgg_19': vgg.vgg_19,
    # Inception
    'inception_v1': inception.inception_v1,
    'inception_v2': inception.inception_v2,
    'inception_v3': inception.inception_v3,
    'inception_v4': inception.inception_v4,
    'inception_resnet_v2': inception.inception_resnet_v2,
    # LeNet
    'lenet': lenet.lenet,
    # ResNet v1
    'resnet_v1_50': resnet_v1.resnet_v1_50,
    'resnet_v1_101': resnet_v1.resnet_v1_101,
    'resnet_v1_152': resnet_v1.resnet_v1_152,
    'resnet_v1_200': resnet_v1.resnet_v1_200,
    # ResNet v2
    'resnet_v2_50': resnet_v2.resnet_v2_50,
    'resnet_v2_101': resnet_v2.resnet_v2_101,
    'resnet_v2_152': resnet_v2.resnet_v2_152,
    'resnet_v2_200': resnet_v2.resnet_v2_200,
    # MobileNet v1
    'mobilenet_v1': mobilenet_v1.mobilenet_v1,
    'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075,
    'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050,
    'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025,
    # NASNet
    'nasnet_cifar': nasnet.build_nasnet_cifar,
    'nasnet_mobile': nasnet.build_nasnet_mobile,
    'nasnet_large': nasnet.build_nasnet_large,
    # DenseNet
    'densenet': densenet.densenet,
}
# Registry of arg-scope factories, keyed by the same network names as
# networks_map. get_network_fn calls the selected factory with
# `weight_decay=...` and applies the result via slim.arg_scope.
arg_scopes_map = {
    # AlexNet / CifarNet / OverFeat
    'alexnet_v2': alexnet.alexnet_v2_arg_scope,
    'cifarnet': cifarnet.cifarnet_arg_scope,
    'overfeat': overfeat.overfeat_arg_scope,
    # VGG (one shared arg scope for every variant)
    'vgg_a': vgg.vgg_arg_scope,
    'vgg_16': vgg.vgg_arg_scope,
    'vgg_19': vgg.vgg_arg_scope,
    # Inception (v1 and v2 are registered with the v3 arg scope)
    'inception_v1': inception.inception_v3_arg_scope,
    'inception_v2': inception.inception_v3_arg_scope,
    'inception_v3': inception.inception_v3_arg_scope,
    'inception_v4': inception.inception_v4_arg_scope,
    'inception_resnet_v2': inception.inception_resnet_v2_arg_scope,
    # LeNet
    'lenet': lenet.lenet_arg_scope,
    # ResNet v1
    'resnet_v1_50': resnet_v1.resnet_arg_scope,
    'resnet_v1_101': resnet_v1.resnet_arg_scope,
    'resnet_v1_152': resnet_v1.resnet_arg_scope,
    'resnet_v1_200': resnet_v1.resnet_arg_scope,
    # ResNet v2
    'resnet_v2_50': resnet_v2.resnet_arg_scope,
    'resnet_v2_101': resnet_v2.resnet_arg_scope,
    'resnet_v2_152': resnet_v2.resnet_arg_scope,
    'resnet_v2_200': resnet_v2.resnet_arg_scope,
    # MobileNet v1 (one shared arg scope for every width)
    'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope,
    'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope,
    'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope,
    'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope,
    # NASNet
    'nasnet_cifar': nasnet.nasnet_cifar_arg_scope,
    'nasnet_mobile': nasnet.nasnet_mobile_arg_scope,
    'nasnet_large': nasnet.nasnet_large_arg_scope,
    # DenseNet
    'densenet': densenet.densenet_arg_scope,
}
def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
  """Builds a `network_fn` so that `logits, end_points = network_fn(images)`.

  The returned callable looks up the arg scope registered for `name`, enters
  it, and invokes the model registered for `name` with the settings captured
  here.

  Args:
    name: Name of the network; must be a key of `networks_map`.
    num_classes: Number of classes for classification. If 0 or None, the
      logits layer is omitted and its input features are returned instead.
    weight_decay: The l2 coefficient for the model weights; forwarded to the
      network's arg-scope factory.
    is_training: `True` when the model is used for training, `False`
      otherwise.

  Returns:
    network_fn: A function applying the model to a batch of images, with
      signature
        net, end_points = network_fn(images, **kwargs)
      where extra keyword arguments are passed through to the model function.
      `images` is a tensor of shape [batch_size, height, width, 3] with
      height = width = network_fn.default_image_size. (Whether other sizes
      are permitted, and how they are treated, depends on the network_fn.)
      `end_points` is a dictionary of intermediate activations.
      `net` is the topmost layer and depends on `num_classes`:
        - non-zero integer: a logits tensor of shape
          [batch_size, num_classes];
        - 0 or `None`: the input to the logits layer, of shape
          [batch_size, 1, 1, num_features] or [batch_size, num_features].
          Dropout has not been applied to it (even if the network's original
          classification does); that is left to the caller.

  Raises:
    ValueError: If network `name` is not recognized.
  """
  if name not in networks_map:
    raise ValueError('Name of network unknown %s' % name)

  model_builder = networks_map[name]

  @functools.wraps(model_builder)
  def network_fn(images, **kwargs):
    # The arg scope is resolved on every call, not once at factory time.
    scope = arg_scopes_map[name](weight_decay=weight_decay)
    with slim.arg_scope(scope):
      logits_and_end_points = model_builder(
          images, num_classes, is_training=is_training, **kwargs)
    return logits_and_end_points

  # Expose the model's preferred input size on the wrapper when it has one.
  if hasattr(model_builder, 'default_image_size'):
    network_fn.default_image_size = model_builder.default_image_size

  return network_fn
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/skyter/quiz-w7-2-densenet.git
git@gitee.com:skyter/quiz-w7-2-densenet.git
skyter
quiz-w7-2-densenet
quiz-w7-2-densenet
master

搜索帮助