使用华为云Ascend910在MNIST上面训练LeNet网络(本地Win10CPU版)

网友投稿 833 2022-05-30

1 使用华为云Ascend910在MNIST上面训练LeNet网络(本地Win10CPU版)

使用华为云Ascend910在MNIST上面训练LeNet网络,上传loss截图和推理精度截图

MindSpore官网:https://www.mindspore.cn/

开源地址:https://gitee.com/mindspore

课程案例:https://gitee.com/mindspore/course

1.1 LeNet

paper: Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. “Gradient-based learning applied to document recognition.” Proceedings of the IEEE, 86(11):2278-2324, November 1998 https://ieeexplore.ieee.org/document/726791

data: http://yann.lecun.com/exdb/mnist/

MindSpore实现:https://gitee.com/mindspore/course/tree/master/lenet5

1.2 环境配置

Python3.7.5 : https://www.python.org/downloads/release/python-375/

MindSpore 1.2.1 : https://www.mindspore.cn/install

# Local CPU on Windows 10, environments managed with mkvirtualenv (virtualenvwrapper);
# run these lines in the Windows command prompt / PowerShell.
mkvirtualenv ms121 -p C:\MySoft\Python37\python.exe
workon ms121
pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/1.2.1/MindSpore/cpu/windows_x64/mindspore-1.2.1-cp37-cp37m-win_amd64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

# GPU server with CUDA 10.1, environments managed with Anaconda.
conda create -n ms121 python=3.7.5
# `conda activate` replaces the deprecated `source activate` (conda >= 4.4).
conda activate ms121
pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/1.2.1/MindSpore/gpu/ubuntu_x86/cuda-10.1/mindspore_gpu-1.2.1-cp37-cp37m-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

# Verify the installation; expected output: 1.2.1
python -c "import mindspore as mp;print(mp.__version__)"

# Other packages
pip install requests

1.3 项目代码

将之前课程代码整理了一下

models/lenet.py

import mindspore.nn as nn
from mindspore.common.initializer import Normal


# LeNet5 network definition
class LeNet5(nn.Cell):
    """LeNet-5 convolutional network.

    Two stages of 5x5 convolution + ReLU + 2x2 max-pooling, followed by
    three fully connected layers. The flattened size ``16 * 5 * 5`` used by
    ``fc1`` corresponds to a 32x32 input image.

    Args:
        num_class (int): number of output classes. Default: 10.
        num_channel (int): number of channels of the input image. Default: 1.
    """

    def __init__(self, num_class=10, num_channel=1):
        super().__init__()
        # The attribute names below become the parameter names stored in
        # checkpoints, so they must stay stable.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        # Parameter-free operators shared by several stages.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Return the unnormalized class scores (logits) for the batch ``x``."""
        # Feature extractor: conv -> relu -> pool, applied twice.
        features = self.conv1(x)
        features = self.max_pool2d(self.relu(features))
        features = self.conv2(features)
        features = self.max_pool2d(self.relu(features))
        # Classifier head.
        hidden = self.flatten(features)
        hidden = self.relu(self.fc1(hidden))
        hidden = self.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits

utils/download.py

数据集文件下载和解压

"""download file and unzip"""
import gzip
import os
import shutil
import sys

import requests


# Decompress a .gz file
def unzipfile(gzip_path):
    """Decompress a gzip file next to the archive.

    The output path is ``gzip_path`` without its trailing ``.gz``
    (e.g. ``train-images-idx3-ubyte.gz`` -> ``train-images-idx3-ubyte``).

    Args:
        gzip_path: path of the ``.gz`` dataset file.

    Raises:
        ValueError: if ``gzip_path`` does not end with ``.gz``; the output
            path would otherwise equal the input path and truncate it.
    """
    if not gzip_path.endswith('.gz'):
        raise ValueError("expected a '.gz' file, got {!r}".format(gzip_path))
    # Strip only the suffix so a '.gz' elsewhere in the path is left intact.
    out_path = gzip_path[:-len('.gz')]
    # Both files are closed even if decompression fails; copyfileobj streams
    # in chunks instead of holding the whole decompressed file in memory.
    with gzip.GzipFile(gzip_path, 'rb') as gz_file, open(out_path, 'wb') as out_file:
        shutil.copyfileobj(gz_file, out_file)


# Download a dataset file while showing progress
def download_progress(url, file_name):
    """Download ``url`` to ``file_name``, then decompress it and delete the archive.

    A 100-character progress bar is written to stdout while downloading.
    Afterwards :func:`unzipfile` extracts ``file_name`` (which must end with
    ``.gz``) and the compressed file is removed.

    Args:
        url: download url.
        file_name: local path of the downloaded ``.gz`` file.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    # NOTE(review): verify=False disables TLS certificate verification for
    # https URLs; kept to preserve existing behaviour, consider removing it.
    with requests.get(url, stream=True, verify=False) as res:
        # Fail loudly instead of saving an HTTP error page as the archive.
        res.raise_for_status()
        # Content-Length may be absent (e.g. chunked responses); 0 = unknown size.
        total_size = int(res.headers.get("Content-Length", 0))
        temp_size = 0
        with open(file_name, "wb") as f:
            for chunk in res.iter_content(chunk_size=1024):
                temp_size += len(chunk)
                f.write(chunk)
                # show download progress
                if total_size > 0:
                    percent = 100 * temp_size / total_size
                    done = min(int(percent), 100)
                    sys.stdout.write("\r[{}{}] {:.2f}%".format("█" * done, " " * (100 - done), percent))
                else:
                    sys.stdout.write("\r{} bytes".format(temp_size))
                sys.stdout.flush()
    print("\n============== {} is already ==============".format(file_name))
    unzipfile(file_name)
    os.remove(file_name)

utils/data_mnist.py

数据集下载和预处理(加载、变换、打乱与分批)

"""download MNIST dataset"""
import os
from urllib.parse import urlparse

import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype
import mindspore.dataset as ds

from .download import download_progress


# Download the dataset from the original website
def download_dataset(mnist_path="./data/MNIST", base_url="http://yann.lecun.com/exdb/mnist/"):
    """Download the MNIST dataset from http://yann.lecun.com/exdb/mnist/.

    Decompressed files are stored under ``<mnist_path>/train/`` and
    ``<mnist_path>/test/``. A file is skipped when its decompressed copy
    already exists, so the function can be re-run safely.

    Args:
        mnist_path: root directory of the dataset. Default: ``./data/MNIST``.
        base_url: directory URL that hosts the four MNIST ``.gz`` archives;
            override it to download from a mirror.
    """
    print("************** Downloading the MNIST dataset **************")
    base_url = base_url.rstrip('/') + '/'
    # (target directory, archive names); tuples keep the download order fixed.
    splits = (
        (mnist_path + "/train/", ("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz")),
        (mnist_path + "/test/", ("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz")),
    )
    for split_path, archive_names in splits:
        # Create each directory independently so that an already existing
        # train/ (or test/) does not prevent creating the other one.
        os.makedirs(split_path, exist_ok=True)
        for archive_name in archive_names:
            file_name = os.path.join(split_path, archive_name)
            # Skip archives whose decompressed copy is already on disk.
            if not os.path.exists(file_name[:-len('.gz')]):
                download_progress(base_url + archive_name, file_name)


# Build the train or test dataset pipeline
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """Create the MNIST input pipeline for training or testing.

    Images are resized to 32x32, scaled to [0, 1], normalized with
    mean 0.1307 / std 0.3081 and transposed from HWC to CHW; labels are cast
    to int32. The samples are then shuffled, batched (dropping the last
    incomplete batch) and repeated.

    Args:
        data_path: directory containing the MNIST files of one split.
        batch_size: number of samples per batch. Default: 32.
        repeat_size: number of times the dataset is repeated. Default: 1.
        num_parallel_workers: number of parallel workers for map operations.

    Returns:
        The processed MindSpore dataset.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Resize images to (32, 32)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)  # normalize images
    rescale_op = CV.Rescale(rescale, shift)  # rescale images
    hwc2chw_op = CV.HWC2CHW()  # change shape from (height, width, channel) to (channel, height, width) to fit network.
    type_cast_op = C.TypeCast(mstype.int32)  # change data type of label to int32 to fit network

    # apply map operations; the image ops run as one stage, in this order
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=[resize_op, rescale_op, rescale_nml_op, hwc2chw_op],
                            input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds

train_lenet.py

"""MindSpore "Hello World": MNIST handwritten digit recognition with LeNet5.

Based on https://gitee.com/mindspore/docs/tree/r1.2/tutorials/tutorial_code/lenet
Official documentation: https://www.mindspore.cn/
"""
import os
import argparse

import mindspore.nn as nn
from mindspore import context, Model, load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

from utils.data_mnist import download_dataset, create_dataset
from models.lenet import LeNet5

# Checkpoint that test_net loads when no explicit file is given.
DEFAULT_CKPT_FILE = "./tmpmodel/checkpoint_lenet-1_1875.ckpt"


# Training
def train_net(network_model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode, batch_size=32):
    """Train ``network_model`` on the ``<data_path>/train`` split.

    Args:
        network_model: the ``Model`` wrapping network, loss and optimizer.
        epoch_size: number of training epochs.
        data_path: dataset root containing the ``train`` directory.
        repeat_size: number of times the dataset is repeated.
        ckpoint_cb: checkpoint callback used to save the parameters.
        sink_mode: whether to use dataset sink mode.
        batch_size: number of samples per batch. Default: 32.
    """
    print("============== Starting Training ==============")
    # load the training dataset
    ds_train = create_dataset(os.path.join(data_path, "train"), batch_size, repeat_size)
    network_model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()],
                        dataset_sink_mode=sink_mode)


# Testing
def test_net(network, network_model, data_path, ckpt_file=DEFAULT_CKPT_FILE):
    """Load ``ckpt_file`` into ``network`` and report accuracy on ``<data_path>/test``.

    Args:
        network: the network whose parameters are overwritten by the checkpoint.
        network_model: the ``Model`` used for evaluation.
        data_path: dataset root containing the ``test`` directory.
        ckpt_file: checkpoint to evaluate. Default: ``DEFAULT_CKPT_FILE``.
    """
    print("============== Starting Testing ==============")
    # load the saved model for evaluation
    param_dict = load_checkpoint(ckpt_file)
    # load the parameters into the network
    load_param_into_net(network, param_dict)
    # load the testing dataset
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = network_model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
    parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: CPU)')
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    # Dataset sink mode is only enabled for non-CPU targets.
    dataset_sink_mode = args.device_target != "CPU"

    # download the mnist dataset
    mnist_path = "./data/MNIST"
    download_dataset(mnist_path)

    # hyper-parameters
    lr = 0.01
    momentum = 0.9
    dataset_size = 1  # passed to create_dataset as repeat_size
    train_epoch = 1

    # create the LeNet network
    net = LeNet5()
    # define the optimizer
    net_opt = nn.Momentum(net.trainable_params(), lr, momentum)
    # define the cross-entropy loss function
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    # save the network model and parameters for subsequent fine-tuning;
    # presumably one checkpoint per epoch (60000 MNIST train images / batch 32 = 1875 steps)
    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
    ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", directory='./tmpmodel', config=config_ck)

    # group layers into an object with training and evaluation features
    model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    train_net(model, train_epoch, mnist_path, dataset_size, ckpoint, dataset_sink_mode)

    # Evaluate the checkpoint written by this run. Its name depends on
    # train_epoch / save_checkpoint_steps, and ModelCheckpoint may switch to a
    # new prefix (e.g. checkpoint_lenet_1) when ./tmpmodel already holds files
    # from an earlier run; fall back to the default name if none is reported.
    ckpt_file = getattr(ckpoint, "latest_ckpt_file_name", None) or DEFAULT_CKPT_FILE
    test_net(net, model, mnist_path, ckpt_file)

运行

python train_lenet.py

运行结果

使用华为云Ascend910在MNIST上面训练LeNet网络(本地Win10CPU版)

Accuracy: 0.9692, Loss: 0.1544

Windows 机器学习 网络

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:Nav2极简笔记01-安装与试用
下一篇:前端代码接入单元测试
相关文章