How do you train an MNIST classifier with MXNet, export the model to ONNX, and load it back for inference?

2026-05-19 18:29

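What this article is for

This article is meant to help graduate students and algorithm engineers (model-tuning "alchemists" included) get started with MXNet quickly. It assumes the reader already understands basic deep-learning and dataset concepts.

System environment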

python 3.7.4
mxnet 1.9.0
mxnet-cu112 1.9.0
onnx 1.9.0
onnxruntime-gpu 1.9.0

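Data preparation

The MNIST CSV file is a 42000x785 matrix.
42000 is the number of images.
Of the 785 columns, the first is the image class (0, 1, 2, ..., 9) and the remaining 784 are the image data, a 28x28 image flattened into a vector. Note that the mnist_train.csv read in step 3 contains 37800 rows, as the shape (37800, 785) printed there shows. The data looks like this: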

1 0 0 0 0 0 0 0 0 0 ..
0 0 0 0 0 0 0 0 0 0 ..
1 0 0 0 0 0 0 0 0 0 ..
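
To make the row layout concrete, here is a minimal sketch that splits one row into its label and a 28x28 pixel grid. It assumes comma-separated values (the default that pd.read_csv expects); the helper name parse_mnist_row and the two sample rows are made up for illustration and are not taken from the real file.

```python
import csv
import io

def parse_mnist_row(row):
    """Split one CSV row of 785 values into (label, 28x28 image)."""
    if len(row) != 785:
        raise ValueError(f"expected 785 columns, got {len(row)}")
    label = int(row[0])
    pixels = [int(v) for v in row[1:]]
    # Rebuild the 28x28 grid row by row from the flat 784-value vector.
    image = [pixels[r * 28:(r + 1) * 28] for r in range(28)]
    return label, image

# Hypothetical two-row file: label 1 with all-zero pixels,
# and label 0 whose last pixel is 255.
row_a = ["1"] + ["0"] * 784
row_b = ["0"] + ["0"] * 783 + ["255"]
csv_text = "\n".join(",".join(r) for r in (row_a, row_b))

parsed = [parse_mnist_row(r) for r in csv.reader(io.StringIO(csv_text))]
for label, image in parsed:
    print(label, len(image), len(image[0]), max(max(line) for line in image))
```

The last value of a row ends up at image[27][27], the bottom-right corner, because the vector is laid out row by row.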

1. Import the required packages

import time
import copy
import onnx
import logging
import platform
import mxnet as mx
import numpy as np
import pandas as pd
import onnxruntime as ort
from sklearn.metrics import accuracy_score

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# MXNet check
if platform.system().lower() != 'windows':
    print(mx.runtime.feature_list())
print(mx.context.num_gpus())
a = mx.nd.ones((2, 3), mx.cpu())
b = a * 2 + 1
print(b)

Output

[✔ CUDA, ✔ CUDNN, ✔ NCCL, ✔ CUDA_RTC, ✖ TENSORRT, ✔ CPU_SSE, ✔ CPU_SSE2, ✔ CPU_SSE3, ✖ CPU_SSE4_1, ✖ CPU_SSE4_2, ✖ CPU_SSE4A, ✖ CPU_AVX, ✖ CPU_AVX2, ✔ OPENMP, ✖ SSE, ✖ F16C, ✖ JEMALLOC, ✔ BLAS_OPEN, ✖ BLAS_ATLAS, ✖ BLAS_MKL, ✖ BLAS_APPLE, ✔ LAPACK, ✔ MKLDNN, ✔ OPENCV, ✖ CAFFE, ✖ PROFILER, ✔ DIST_KVSTORE, ✖ CXX14, ✖ INT64_TENSOR_SIZE, ✔ SIGNAL_HANDLER, ✖ DEBUG, ✖ TVM_OP]
1
[[3. 3. 3.]
 [3. 3. 3.]]
<NDArray 2x3 @cpu(0)>

2. Parameter setup
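
S_DEVICE, which is set in this step, is used again in step 10 to choose the ONNX Runtime execution provider. If the script may also run on a machine without a GPU, one option is a small helper that falls back to the CPU provider. The sketch below is an assumption about how that could look, not part of the original script; pick_providers is a hypothetical name, and the two "available" lists are hard-coded stand-ins.

```python
def pick_providers(device, available):
    """Return an ordered ONNX Runtime provider list for the requested device."""
    preferred = []
    if device == "cuda" and "CUDAExecutionProvider" in available:
        preferred.append("CUDAExecutionProvider")
    # The CPU provider goes last so it only handles what earlier providers cannot.
    preferred.append("CPUExecutionProvider")
    return preferred

# In a real script the list would come from onnxruntime.get_available_providers();
# these are stand-ins for a GPU machine and a CPU-only machine.
gpu_machine = ["CUDAExecutionProvider", "CPUExecutionProvider"]
cpu_machine = ["CPUExecutionProvider"]

print(pick_providers("cuda", gpu_machine))
print(pick_providers("cuda", cpu_machine))
print(pick_providers("cpu", gpu_machine))
```

The returned list can be passed as the providers argument of ort.InferenceSession, in the same place where the original code passes a single-element list.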

N_EPOCH = 1
N_BATCH = 32
N_BATCH_NUM = 900
S_DATA_PATH = r"mnist_train.csv"
S_MODEL_PATH = r"mxnet_cnn"
S_SYM_PATH = './mxnet_cnn-symbol.json'
S_PARAMS_PATH = './mxnet_cnn-0001.params'
S_ONNX_MODEL_PATH = './mxnet_cnn.onnx'
S_DEVICE, N_DEVICE_ID, S_DEVICE_FULL = "cuda", 0, "cuda:0"
# S_DEVICE, N_DEVICE_ID, S_DEVICE_FULL = "cpu", 0, "cpu"
CTX = mx.cpu() if S_DEVICE == "cpu" else mx.gpu(N_DEVICE_ID)
B_IS_UNIX = True

3. Read the data
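
The code in this step takes the first N_BATCH * N_BATCH_NUM rows as the training set, keeps the rest for testing, and reshapes each 784-value vector into the NCHW layout (batch, channel, height, width) that Conv2D expects. The following is a small numpy sketch of the same split on synthetic data; split_and_reshape is a hypothetical helper and the 20-row matrix is made up so the example runs without the CSV file.

```python
import numpy as np

def split_and_reshape(mat, n_batch, n_batch_num):
    """Split a (rows, 785) matrix into NCHW train/test arrays plus labels."""
    n_train = n_batch * n_batch_num
    x = mat[:, 1:].astype(np.float32) / 255  # scale pixels to [0, 1]
    y = mat[:, 0]
    x_train = x[:n_train].reshape(-1, 1, 28, 28)
    x_test = x[n_train:].reshape(-1, 1, 28, 28)
    return x_train, y[:n_train], x_test, y[n_train:]

# Tiny stand-in for the CSV: 20 rows, labels 0..9 repeated, all pixels 255.
fake = np.full((20, 785), 255, dtype=np.int64)
fake[:, 0] = np.arange(20) % 10

x_train, y_train, x_test, y_test = split_and_reshape(fake, n_batch=4, n_batch_num=3)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

# The same arithmetic with the article's constants and its 37800-row file.
n_train = 32 * 900
print(n_train, 37800 - n_train)
```

With the article's constants the split is 28800 training rows and 9000 test rows, which matches the shapes printed in this step's output.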

df = pd.read_csv(S_DATA_PATH, header=None)
print(df.shape)
np_mat = np.array(df)
print(np_mat.shape)

# Column 0 is the label, columns 1..784 are the pixels
X = np_mat[:, 1:]
Y = np_mat[:, 0]
X = X.astype(np.float32) / 255  # scale pixels to [0, 1]

# The first N_BATCH * N_BATCH_NUM rows are for training, the rest for testing
X_train = X[:N_BATCH * N_BATCH_NUM]
X_test = X[N_BATCH * N_BATCH_NUM:]
Y_train = Y[:N_BATCH * N_BATCH_NUM]
Y_test = Y[N_BATCH * N_BATCH_NUM:]

# Reshape to NCHW: (batch, 1 channel, 28, 28)
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
print(X_train.shape)
print(Y_train.shape)
print(X_test.shape)
print(Y_test.shape)

train_iter = mx.io.NDArrayIter(X_train, Y_train, batch_size=N_BATCH)
test_iter = mx.io.NDArrayIter(X_test, Y_test, batch_size=N_BATCH)
test_iter_2 = copy.copy(test_iter)

Output

(37800, 785)
(37800, 785)
(28800, 1, 28, 28)
(28800,)
(9000, 1, 28, 28)
(9000,)

4. Build the model
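
Gluon infers the input size of each layer on the first forward pass, which is why the printed network shows None as the input size. It can still be useful to trace the shapes by hand for the network defined in this step. The sketch below does that arithmetic in plain Python, assuming the Gluon defaults of no padding and stride 1 for Conv2D; conv_out is a hypothetical helper name.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling window along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28
conv_side = conv_out(side, kernel=3)                 # Conv2D(kernel_size=3): 28 -> 26
pool_side = conv_out(conv_side, kernel=2, stride=2)  # MaxPool2D(2, strides=2): 26 -> 13
flat = 32 * pool_side * pool_side                    # Flatten over 32 channels

# Weights plus biases for each trainable layer.
conv_params = 32 * 1 * 3 * 3 + 32
dense0_params = flat * 128 + 128
dense1_params = 128 * 10 + 10
total_params = conv_params + dense0_params + dense1_params

print(conv_side, pool_side, flat)
print(conv_params, dense0_params, dense1_params, total_params)
```

Almost all of the parameters sit in the first Dense layer, because it connects 5408 flattened features to 128 units.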

net = mx.gluon.nn.HybridSequential()
with net.name_scope():
    net.add(mx.gluon.nn.Conv2D(channels=32, kernel_size=3, activation='relu'))  # bx28x28 ==>
    net.add(mx.gluon.nn.MaxPool2D(pool_size=2, strides=2))
    net.add(mx.gluon.nn.Flatten())
    net.add(mx.gluon.nn.Dense(128, activation="relu"))
    net.add(mx.gluon.nn.Dense(10))
net.hybridize()
print(net)
net.collect_params().initialize(mx.init.Xavier(), ctx=CTX)
softmax_cross_entropy = mx.gluon.loss.SoftmaxCrossEntropyLoss()
trainer = mx.gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': .001})

Output

HybridSequential(
  (0): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), Activation(relu))
  (1): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
  (2): Flatten
  (3): Dense(None -> 128, Activation(relu))
  (4): Dense(None -> 10, linear)
)

5. Train the model
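
The loss used in this step, SoftmaxCrossEntropyLoss, takes raw scores (logits) and integer labels and returns one value per sample: the negative log of the softmax probability assigned to the true class. A minimal numpy sketch of that formula follows; softmax_ce_numpy is a hypothetical name and the inputs are made up, so it illustrates the math rather than reproduces MXNet's implementation.

```python
import numpy as np

def softmax_ce_numpy(logits, labels):
    """Per-sample cross-entropy between softmax(logits) and integer labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

# Untrained-like case: equal scores for all 10 digits.
uniform = np.zeros((2, 10), dtype=np.float32)
labels = np.array([3, 7])
loss_uniform = softmax_ce_numpy(uniform, labels)
print(loss_uniform.mean())

# Confident and correct: a large score on the true class gives a loss near 0.
confident = np.zeros((1, 10), dtype=np.float32)
confident[0, 3] = 10.0
loss_confident = softmax_ce_numpy(confident, np.array([3]))
print(loss_confident[0])
```

Equal scores over 10 classes give a loss of ln(10), about 2.30, which is close to the 2.2868602 printed for batch 0 in the training output, before the network has learned anything.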

for epoch in range(N_EPOCH):
    train_iter.reset()  # rewind the iterator so every epoch sees the full training set
    for batch_num, itr in enumerate(train_iter):
        data = itr.data[0].as_in_context(CTX)
        label = itr.label[0].as_in_context(CTX)
        with mx.autograd.record():
            output = net(data)  # Run the forward pass
            loss = softmax_cross_entropy(output, label)  # Compute the loss
        loss.backward()
        trainer.step(data.shape[0])
        if batch_num % 50 == 0:  # Print loss once in a while
            curr_loss = mx.nd.mean(loss)  # .asscalar()
            pred = mx.nd.argmax(output, axis=1)
            np_pred, np_lable = pred.asnumpy(), label.asnumpy()
            f_acc = accuracy_score(np_lable, np_pred)
            print(f"Epoch: {epoch}; Batch {batch_num}; ACC {f_acc}")
            print(f"loss: {curr_loss}")
            print()
            # print("Epoch: %d; Batch %d; Loss %s; ACC %f" %
            #       (epoch, batch_num, str(curr_loss), f_acc))
    print()

Output

Epoch: 0; Batch 0; ACC 0.09375
loss: [2.2868602] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 50; ACC 0.875
loss: [0.512461] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 100; ACC 0.90625
loss: [0.43415746] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 150; ACC 0.84375
loss: [0.3854709] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 200; ACC 1.0
loss: [0.04192135] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 250; ACC 0.90625
loss: [0.21156572] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 300; ACC 0.9375
loss: [0.15938525] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 350; ACC 1.0
loss: [0.0379494] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 400; ACC 0.96875
loss: [0.17104594] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 450; ACC 0.96875
loss: [0.12192786] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 500; ACC 0.96875
loss: [0.09210478] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 550; ACC 0.9375
loss: [0.13728428] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 600; ACC 0.96875
loss: [0.0762211] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 650; ACC 0.96875
loss: [0.12162409] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 700; ACC 1.0
loss: [0.04334489] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 750; ACC 1.0
loss: [0.06458903] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 800; ACC 0.96875
loss: [0.07410634] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 850; ACC 0.96875
loss: [0.14233188] <NDArray 1 @gpu(0)>

6. Predict with the model
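
The loop in this step prints the accuracy of a single batch of 32 images every 50 batches, so each number is a noisy sample. One way to get a single accuracy for the whole test set is to accumulate correct predictions across batches. The sketch below shows the idea on made-up numpy data; dataset_accuracy is a hypothetical helper. It assumes each batch also reports how many trailing samples are padding, which is what the pad attribute of an MXNet DataBatch is for: with 9000 test images and a batch size of 32, the last batch holds only 8 real samples.

```python
import numpy as np

def dataset_accuracy(batches):
    """Accuracy over all batches; each item is (logits, labels, pad).

    pad is the number of filler samples at the end of the batch, which
    are skipped so that they are not counted twice.
    """
    n_correct, n_total = 0, 0
    for logits, labels, pad in batches:
        keep = len(labels) - pad
        pred = np.argmax(logits[:keep], axis=1)
        n_correct += int((pred == labels[:keep]).sum())
        n_total += keep
    return n_correct / n_total

# Two hand-made batches of 4 samples over 3 classes; the second has 1 padded sample.
batch_a = (np.array([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.], [2., 0., 0.]]),
           np.array([0, 1, 2, 1]), 0)
batch_b = (np.array([[0., 2., 0.], [0., 0., 2.], [2., 0., 0.], [2., 0., 0.]]),
           np.array([1, 2, 0, 0]), 1)
acc = dataset_accuracy([batch_a, batch_b])
print(acc)
```

In the MXNet loop, the three values per batch would correspond to output.asnumpy(), label.asnumpy() and itr.pad.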

# Note: epoch still holds the last value from the training loop
for batch_num, itr in enumerate(test_iter_2):
    data = itr.data[0].as_in_context(CTX)
    label = itr.label[0].as_in_context(CTX)
    output = net(data)  # Run the forward pass
    loss = softmax_cross_entropy(output, label)  # Compute the loss
    if batch_num % 50 == 0:  # Print loss once in a while
        curr_loss = mx.nd.mean(loss)  # .asscalar()
        pred = mx.nd.argmax(output, axis=1)
        np_pred, np_lable = pred.asnumpy(), label.asnumpy()
        f_acc = accuracy_score(np_lable, np_pred)
        print(f"Epoch: {epoch}; Batch {batch_num}; ACC {f_acc}")
        print(f"loss: {curr_loss}")
        print()

Output

Epoch: 0; Batch 0; ACC 0.96875
loss: [0.22968824] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 50; ACC 0.96875
loss: [0.05668993] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 100; ACC 0.96875
loss: [0.08171713] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 150; ACC 1.0
loss: [0.02264522] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 200; ACC 0.96875
loss: [0.080383] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 250; ACC 1.0
loss: [0.03774196] <NDArray 1 @gpu(0)>

7. Save the model
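
net.export writes two files whose names are derived from the prefix and the epoch number, and these are the files that S_SYM_PATH and S_PARAMS_PATH from step 2 point to. The sketch below spells out the naming pattern, assuming the MXNet 1.x convention of a zero-padded four-digit epoch; export_file_names is a hypothetical helper, not an MXNet function.

```python
def export_file_names(prefix, epoch):
    """File names written by net.export(prefix, epoch=epoch) in MXNet 1.x."""
    return f"{prefix}-symbol.json", f"{prefix}-{epoch:04d}.params"

sym_path, params_path = export_file_names("mxnet_cnn", 1)
print(sym_path)
print(params_path)
```

If N_EPOCH is changed, the params file name changes with it, so S_PARAMS_PATH has to be updated to match before the model is reloaded.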

net.export(S_MODEL_PATH, epoch=N_EPOCH)  # Save the model structure and all parameters

8. Load the model and run the loaded model
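
A reloaded model should reproduce the original network's outputs. This step checks that indirectly by printing the same loss and accuracy figures as step 6; a more direct check is to compare the raw outputs numerically. The sketch below shows such a comparison on made-up arrays. outputs_agree is a hypothetical helper and the tolerance of 1e-5 is an assumption, not a value from the original article.

```python
import numpy as np

def outputs_agree(reference, candidate, atol=1e-5):
    """True when two output arrays have the same shape and close values."""
    if reference.shape != candidate.shape:
        return False
    return bool(np.allclose(reference, candidate, rtol=0.0, atol=atol))

# Stand-ins for net(data).asnumpy() and load_net(data).asnumpy().
original = np.array([[0.1, 2.5, -1.0], [3.0, -0.2, 0.4]], dtype=np.float32)
reloaded = original + 1e-7        # tiny numerical noise
different = original.copy()
different[0, 0] += 0.5            # a real mismatch

print(outputs_agree(original, reloaded))
print(outputs_agree(original, different))
```

The same comparison can be applied later to the ONNX Runtime output, where small floating-point differences between runtimes are more likely.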

print("load net and do test")
load_net = mx.gluon.nn.SymbolBlock.imports(S_SYM_PATH, ['data'], S_PARAMS_PATH, ctx=CTX)  # Load the model
print("load ok")
for batch_num, itr in enumerate(test_iter):  # Test
    data = itr.data[0].as_in_context(CTX)
    label = itr.label[0].as_in_context(CTX)
    output = load_net(data)  # Run the forward pass
    loss = softmax_cross_entropy(output, label)  # Compute the loss
    if batch_num % 50 == 0:  # Print loss once in a while
        curr_loss = mx.nd.mean(loss)  # .asscalar()
        pred = mx.nd.argmax(output, axis=1)
        np_pred, np_lable = pred.asnumpy(), label.asnumpy()
        f_acc = accuracy_score(np_lable, np_pred)
        print(f"Epoch: {epoch}; Batch {batch_num}; ACC {f_acc}")
        print(f"loss: {curr_loss}")
        print()
print("finish")

Output

load net and do test
load ok
Epoch: 0; Batch 0; ACC 0.96875
loss: [0.22968824] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 50; ACC 0.96875
loss: [0.05668993] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 100; ACC 0.96875
loss: [0.08171713] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 150; ACC 1.0
loss: [0.02264522] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 200; ACC 0.96875
loss: [0.080383] <NDArray 1 @gpu(0)>
Epoch: 0; Batch 250; ACC 1.0
loss: [0.03774196] <NDArray 1 @gpu(0)>
finish

9. Export to ONNX

if platform.system().lower() != 'windows':
    mx.onnx.export_model(S_SYM_PATH, S_PARAMS_PATH, [(32, 1, 28, 28)], [np.float32], S_ONNX_MODEL_PATH, verbose=True, dynamic=True)

Output

INFO:root:Converting json and weight file to sym and params
INFO:root:Converting idx: 0, op: null, name: data
INFO:root:Converting idx: 1, op: null, name: hybridsequential0_conv0_weight
INFO:root:Converting idx: 2, op: null, name: hybridsequential0_conv0_bias
INFO:root:Converting idx: 3, op: Convolution, name: hybridsequential0_conv0_fwd
INFO:root:Converting idx: 4, op: Activation, name: hybridsequential0_conv0_relu_fwd
INFO:root:Converting idx: 5, op: Pooling, name: hybridsequential0_pool0_fwd
INFO:root:Converting idx: 6, op: Flatten, name: hybridsequential0_flatten0_flatten0
INFO:root:Converting idx: 7, op: null, name: hybridsequential0_dense0_weight
INFO:root:Converting idx: 8, op: null, name: hybridsequential0_dense0_bias
INFO:root:Converting idx: 9, op: FullyConnected, name: hybridsequential0_dense0_fwd
INFO:root:Converting idx: 10, op: Activation, name: hybridsequential0_dense0_relu_fwd
INFO:root:Converting idx: 11, op: null, name: hybridsequential0_dense1_weight
INFO:root:Converting idx: 12, op: null, name: hybridsequential0_dense1_bias
INFO:root:Converting idx: 13, op: FullyConnected, name: hybridsequential0_dense1_fwd
INFO:root:Output node is: hybridsequential0_dense1_fwd
INFO:root:Input shape of the model [(32, 1, 28, 28)]
INFO:root:Exported ONNX file ./mxnet_cnn.onnx saved to disk

10. Load the ONNX model and run it
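
In the exported file, the weights and biases are listed as graph inputs next to the real input data, as this step's output shows. The code in this step relies on data being first in that list. A slightly more defensive option is to drop every name that is also a stored weight (an initializer). The sketch below demonstrates the filtering on the names printed in this article, so it runs without onnx installed; real_input_names is a hypothetical helper, and treating all six weight and bias names as initializers is an assumption.

```python
def real_input_names(graph_input_names, initializer_names):
    """Keep only graph inputs that are not stored weights."""
    weights = set(initializer_names)
    return [name for name in graph_input_names if name not in weights]

# Names as printed in this article's step 10 output. With a loaded model they
# would come from model.graph.input and model.graph.initializer respectively.
graph_inputs = [
    "data",
    "hybridsequential0_conv0_weight", "hybridsequential0_conv0_bias",
    "hybridsequential0_dense0_weight", "hybridsequential0_dense0_bias",
    "hybridsequential0_dense1_weight", "hybridsequential0_dense1_bias",
]
# Assumption: every weight and bias above is also an initializer in the file.
initializers = graph_inputs[1:]

feed_names = real_input_names(graph_inputs, initializers)
print(feed_names)
```

Only the names that survive the filter need to be supplied in the feed dictionary passed to ort_session.run.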

if platform.system().lower() != 'windows':
    model = onnx.load(S_ONNX_MODEL_PATH)
    print(onnx.checker.check_model(model))  # Check that the model is well formed
    # print(onnx.helper.printable_graph(model.graph))  # Print a human readable representation of the graph
    ls_input_name, ls_output_name = [input.name for input in model.graph.input], [output.name for output in model.graph.output]
    print("input name ", ls_input_name)
    print("output name ", ls_output_name)
    s_input_name = ls_input_name[0]

    # Use the first two training batches (64 images) as the test input
    x_input = X_train[:N_BATCH*2, :, :, :].astype(np.float32)
    ort_val = ort.OrtValue.ortvalue_from_numpy(x_input, S_DEVICE, N_DEVICE_ID)
    print("val device ", ort_val.device_name())
    print("val shape ", ort_val.shape())
    print("val data type ", ort_val.data_type())
    print("is_tensor ", ort_val.is_tensor())
    print("array_equal ", np.array_equal(ort_val.numpy(), x_input))

    providers = 'CUDAExecutionProvider' if S_DEVICE == "cuda" else 'CPUExecutionProvider'
    print("providers ", providers)
    ort_session = ort.InferenceSession(S_ONNX_MODEL_PATH, providers=[providers])  # Run on the GPU
    ort_session.set_providers([providers])
    outputs = ort_session.run(None, {s_input_name: ort_val})
    print("sess env ", ort_session.get_providers())
    print(type(outputs))
    print(outputs[0])
    '''
    For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
    means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
    '''

Output

None
input name ['data', 'hybridsequential0_conv0_weight', 'hybridsequential0_conv0_bias', 'hybridsequential0_dense0_weight', 'hybridsequential0_dense0_bias', 'hybridsequential0_dense1_weight', 'hybridsequential0_dense1_bias']
output name ['hybridsequential0_dense1_fwd']
val device cuda
val shape [64, 1, 28, 28]
val data type tensor(float)
is_tensor True
array_equal True
providers CUDAExecutionProvider
sess env ['CUDAExecutionProvider', 'CPUExecutionProvider']
<class 'list'>
[[-2.69336128e+00 8.42524242e+00 -3.34120363e-01 -1.17912292e+00 3.82278800e-01 -3.60794234e+00 3.58125120e-01 -2.58064723e+00 1.55215383e+00 -2.03553891e+00]
 [ 1.02665892e+01 -6.65782404e+00 -2.04501271e-01 -2.25653172e+00 -6.31941366e+00 1.13084137e+00 -3.83885235e-01 8.22283030e-01 -1.21192622e+00 3.33601260e+00]
 [-3.27186418e+00 1.00050325e+01 5.39114550e-02 -1.44938648e+00 -9.89762247e-01 -2.09957671e+00 -1.49389958e+00 6.52510405e-01 1.73153889e+00 -1.25597775e+00]
 [ 5.72116375e-01 -3.36192799e+00 -6.68362260e-01 -2.81247520e+00 8.36382389e+00 -3.67477946e-02 2.23792076e+00 -2.91093756e-02 -4.56922323e-01 -6.77382052e-01]
 [ 1.18602552e+01 -5.09683752e+00 4.54203248e-01 -2.55723000e+00 -8.68753910e+00 6.96948707e-01 -1.50591761e-01 -3.62227589e-01 9.83437955e-01 7.46711075e-01]
 [ 7.33289337e+00 -6.65414715e+00 1.57180679e+00 -2.62657452e+00 4.11511570e-01 -1.35336161e+00 -1.40558392e-01 3.81030589e-01 1.73799121e+00 8.02671254e-01]
 [-3.02898431e+00 1.26861107e+00 -2.04946566e+00 -2.52499342e-01 -2.73597687e-01 -3.01714039e+00 -7.10914516e+00 1.10452967e+01 -5.82177579e-01 1.86712158e+00]
 [-7.78098392e+00 -6.01984358e+00 1.23355007e+00 1.18682652e+01 -9.83472538e+00 8.27242088e+00 -1.02135544e+01 3.95661980e-01 6.63226461e+00 3.33681512e+00]
 [-2.72245955e+00 -6.74849796e+00 -6.24665642e+00 3.11165476e+00 -4.71174330e-01 1.22390661e+01 -1.23519528e+00 -1.24356663e+00 1.26693976e+00 5.81862879e+00]
 [-5.65229607e+00 -1.25138938e+00 3.68380380e+00 1.24947300e+01 -8.21508980e+00 1.61641145e+00 -8.01925087e+00 8.37018967e-01 -2.64613247e+00 7.92313635e-01]
 [-3.73405719e+00 -3.41621947e+00 -7.94842839e-01 4.55352879e+00 -2.28238964e+00 1.88887548e+00 -5.84129477e+00 6.03430390e-01 1.05920439e+01 2.25430655e+00]
 [-5.44103146e+00 -5.48421431e+00 -3.62234282e+00 1.20194650e+00 3.48899674e+00 1.50794566e+00 -6.30612850e+00 4.01568127e+00 1.61318648e+00 9.87832165e+00]
 [-3.34073186e+00 8.10987663e+00 -6.43497527e-01 -1.64372277e+00 -4.42907363e-01 -1.46176386e+00 -8.56327295e-01 5.20323329e-02 1.73289025e+00 -8.17061365e-01]
 [-6.88457203e+00 1.38391244e+00 1.33096969e+00 1.28132534e+01 -6.20939922e+00 1.48244214e+00 -6.59804583e+00 -1.38118923e+00 4.26289368e+00 -1.22962976e+00]
 [-6.09051991e+00 -3.15275192e+00 1.79273260e+00 9.92699528e+00 -5.97349882e+00 3.68225765e+00 -6.47421646e+00 -1.99264419e+00 2.15714622e+00 2.32836318e+00]
 [-3.25946307e+00 8.14360428e+00 -1.00535810e+00 -2.37552500e+00 2.38139248e+00 -2.92597318e+00 -1.54173911e+00 2.25682306e+00 -2.83430189e-01 -1.33554244e+00]
 [-2.99147058e+00 3.86941671e+00 8.82810593e+00 2.20121431e+00 -8.40485859e+00 -8.66728902e-01 -5.97998762e+00 -5.21699572e+00 5.80638123e+00 -2.57314467e+00]
 [ 8.64277363e+00 -4.99241495e+00 2.84688592e+00 -4.15350378e-01 -1.87728360e-01 -2.40291572e+00 4.42544132e-01 -4.54446167e-01 -1.88113344e+00 -1.23334014e+00]
 [-2.00169897e+00 -2.65497804e+00 1.18750989e+00 9.70900059e-01 -4.53840446e+00 -2.65584946e+00 -8.23472023e+00 9.93836498e+00 -5.57100773e-01 3.42955470e+00]
 [-3.57249069e+00 -5.03176594e+00 -1.79369414e+00 -5.03321826e-01 -1.97100627e+00 9.01608944e+00 6.62497377e+00 -5.48222637e+00 6.09256268e+00 -4.71334040e-01]
 [-5.27715540e+00 -7.84428477e-01 -6.26944721e-01 3.87298250e+00 -1.88836837e+00 1.15252662e+00 -2.98473048e+00 -3.10233998e+00 9.71112537e+00 3.10839200e+00]
 [-9.50223565e-01 -6.47654009e+00 2.26750326e+00 1.95419586e+00 1.68217969e+00 1.66003108e+00 9.82697105e+00 -9.94868219e-01 -2.03924966e+00 -1.88321277e-01]
 [-3.11575246e+00 3.43664408e+00 1.19877796e+01 4.36916590e+00 -1.17812777e+01 -1.69431508e+00 -5.82668829e+00 -5.09948444e+00 4.15738583e+00 -4.30461359e+00]
 [ 9.72177792e+00 -5.31352401e-01 -1.21784186e+00 -1.07392669e+00 -7.11223555e+00 1.67838800e+00 1.01826215e+00 -8.88240516e-01 6.95959151e-01 2.38748863e-01]
 [-2.06619406e+00 1.86608231e+00 1.12100420e+01 4.22539425e+00 -1.21493711e+01 -4.57662535e+00 -6.88935089e+00 -9.81215835e-01 3.87611055e+00 -3.28470826e+00]
 [-6.73031902e+00 -2.54390073e+00 -1.10151446e+00 1.51524162e+01 -1.10052080e+01 6.60323954e+00 -7.94388771e+00 3.31939721e+00 -1.40840662e+00 2.65730071e+00]
 [-1.96954179e+00 -1.13817227e+00 9.40351069e-01 -1.75684047e+00 3.60373807e+00 2.01377797e+00 1.00558109e+01 -1.10547984e+00 5.17374456e-01 -3.94047165e+00]
 [-5.81787634e+00 -1.20211565e+00 -3.53216052e+00 1.17569458e+00 4.21314144e+00 -2.53644252e+00 -7.64466667e+00 4.19782829e+00 4.28840429e-01 1.04579344e+01]
 [-4.20310974e+00 -3.19272375e+00 -4.62792778e+00 2.71683741e+00 4.43899345e+00 3.31357956e-01 -6.24839544e+00 3.80388188e+00 -1.22620119e-02 9.65024757e+00]
 [-8.26945066e-01 -5.25947523e+00 3.72772887e-02 2.30585241e+00 -4.95726252e+00 -1.19987357e+00 -1.20395079e+01 1.53253164e+01 -2.10372299e-01 1.89387524e+00]
 [-5.09596729e+00 -7.76027665e-02 -9.53466833e-01 2.89041376e+00 -1.50858855e+00 2.27854323e+00 -1.95591903e+00 -3.15785193e+00 1.00103540e+01 1.08987451e+00]
 [-4.01680946e-01 -4.62062168e+00 3.90530303e-02 -1.66790807e+00 5.43311167e+00 -1.78802896e+00 -2.88405848e+00 2.93439984e+00 -2.16558409e+00 8.71198368e+00]
 [-1.29969406e+00 -3.92871022e+00 -3.82151055e+00 -2.93831253e+00 1.03674269e+01 8.88044477e-01 7.88922787e-01 3.86107159e+00 1.60807288e+00 -3.76913965e-01]
 [-2.25020099e+00 -8.17249107e+00 -1.82360613e+00 -5.90175152e-01 4.72407389e+00 9.39436078e-01 -3.85674310e+00 3.95303702e+00 1.83473241e+00 9.13874054e+00]
 [-3.26617742e+00 -2.91517663e+00 8.37770653e+00 1.61820054e-01 -2.98638320e+00 -2.47211266e+00 5.08574843e-01 4.65608168e+00 2.66001201e+00 -4.67262363e+00]
 [-2.29874635e+00 7.77097034e+00 1.11359918e+00 -2.06103897e+00 -7.61267126e-01 1.00877440e+00 1.47708499e+00 -1.20483887e+00 1.99922264e+00 -3.81118345e+00]
 [-6.87821198e+00 -9.18823600e-01 2.16773844e+00 1.07671242e+01 -7.48823595e+00 2.90310860e+00 -1.02075748e+01 -3.83400464e+00 4.76818371e+00 4.06564474e+00]
 [-2.06487226e+00 8.76828384e+00 1.10449910e+00 -2.29669046e+00 -1.15668392e+00 -2.50351834e+00 -1.69508122e-02 -1.05916834e+00 1.91057479e+00 -2.64592767e+00]
 [-2.24318981e+00 9.02024174e+00 1.38990092e+00 -2.72154903e+00 1.46101296e-01 -4.43454313e+00 -8.21092844e-01 -2.40808502e-01 3.36577922e-01 -2.63193059e+00]
 [-4.35961342e+00 -7.74704576e-01 -2.74345660e+00 -3.27951574e+00 1.50971518e+01 -2.80669570e+00 -1.28740633e+00 3.94157290e+00 4.20372874e-01 8.37333024e-01]
 [-2.80749345e+00 -3.33036280e+00 -1.00865018e+00 4.57633829e+00 -5.03952360e+00 2.93345642e+00 -8.54609489e+00 2.26549125e+00 4.73208952e+00 5.93849993e+00]
 [-3.31042576e+00 9.97719002e+00 4.38573778e-01 -1.35296178e+00 -1.21057940e+00 -2.46178842e+00 -8.34564090e-01 2.19030753e-01 2.03147411e+00 -1.80211437e+00]
 [-2.14534068e+00 -2.93023801e+00 4.41405416e-01 -2.29865336e+00 1.47422533e+01 -1.86358702e+00 4.61042017e-01 -6.20108247e-01 -1.36792552e+00 2.14018896e-01]
 [-2.64241481e+00 -2.28332114e+00 2.01109338e+00 9.67352509e-01 6.09287119e+00 -2.35626236e-01 -3.02941656e+00 4.32772923e+00 -4.63955021e+00 3.73136783e+00]
 [-4.55847168e+00 1.04014337e+00 9.12987328e+00 2.06433630e+00 -1.67355919e+00 -1.49593079e+00 4.09124941e-01 2.41894865e+00 -6.86871633e-02 -4.42179346e+00]
 [-6.10608578e-01 -2.73860097e+00 -1.09864855e+00 -5.68899512e-01 2.45831108e+00 5.50326490e+00 1.22601585e+01 -5.23877192e+00 2.11066246e+00 -2.98584485e+00]
 [-5.96745872e+00 -1.91458237e+00 -3.10774088e-01 1.00216856e+01 -3.81997776e+00 7.14399862e+00 -4.61386251e+00 -5.18248987e+00 4.25162363e+00 1.18878789e-01]
 [-4.24126434e+00 -9.63249326e-01 1.06391013e+00 4.45316315e+00 -4.47125340e+00 -1.21906054e+00 -9.75789547e+00 1.10335569e+01 -1.17632782e+00 8.78942788e-01]
 [ 3.76867175e-01 -4.75102758e+00 -2.59345794e+00 -3.96257102e-01 -7.50329159e-03 1.81642962e+00 -6.01041269e+00 9.97849655e+00 -2.57468176e+00 5.00644207e+00]
 [-2.21995306e+00 -4.23465443e+00 -2.19536662e+00 -3.71420813e+00 1.49460211e+01 -2.73240638e+00 3.03538777e-02 4.29108334e+00 1.21963143e+00 6.85284913e-01]
 [-1.93794692e+00 -3.39166284e+00 -3.41372967e-01 -2.16144085e-01 1.32588074e-01 -3.83050554e-02 -7.32452822e+00 9.68561840e+00 -1.16044319e+00 3.63913298e+00]
 [-3.31972694e+00 -4.84879112e+00 -3.41381001e+00 1.93332338e+00 -5.16150045e+00 1.05730495e+01 1.52961123e+00 -4.64916992e+00 4.11477995e+00 4.80105543e+00]
 [-3.22445488e+00 1.01509037e+01 6.03635088e-02 -1.48001885e+00 -3.28380197e-01 -2.36782789e+00 -8.66441727e-01 6.68077886e-01 1.45576179e+00 -1.95271623e+00]
 [-5.64618587e+00 -3.71156931e+00 -2.31174397e+00 1.76912701e+00 2.95111752e+00 2.09562635e+00 -5.34609461e+00 -2.59399921e-01 1.10373080e+00 1.10444403e+01]
 [ 1.32743053e+01 -5.11389780e+00 1.10446036e+00 -5.01595545e+00 -4.87907743e+00 4.46035750e-02 4.57046556e+00 -1.76434004e+00 -5.40824793e-02 -1.01547205e+00]
 [-4.27589798e+00 1.26832044e+00 6.49948978e+00 -8.06193352e-02 1.34645328e-01 6.92090929e-01 -8.92272711e-01 2.15252757e+00 -9.95365143e-01 -2.46636438e+00]
 [-5.08349514e+00 -1.79646879e-01 9.57399654e+00 3.35643005e+00 -3.42183971e+00 -2.33653522e+00 -1.98645079e+00 1.50538552e+00 -1.43313253e+00 -9.75638926e-01]
 [-8.48450565e+00 -4.32870531e+00 -1.54253757e+00 1.61029205e+01 -8.41084957e+00 5.45092726e+00 -1.00705996e+01 2.31331086e+00 1.05400957e-01 5.19563723e+00]
 [-4.89484310e+00 -4.33120441e+00 -3.71932483e+00 -1.18670315e-01 4.70187807e+00 2.67010808e-01 -5.52650118e+00 5.62736416e+00 5.05499423e-01 8.69004250e+00]
 [-1.94048929e+00 7.98250961e+00 -3.75457823e-01 -2.09433126e+00 -4.63896915e-02 -2.64650702e+00 -5.73348761e-01 -1.34597674e-01 1.26143038e+00 -1.32178509e+00]
 [-6.36190224e+00 5.44552183e+00 1.89667892e+00 3.91665459e+00 -1.99860716e+00 2.76620483e+00 3.57649279e+00 -3.53898597e+00 -1.21411644e-01 -2.60307193e+00]
 [-3.60333633e+00 4.92228556e+00 2.87770915e+00 1.31902504e+00 -4.28756446e-01 -3.29862523e+00 -2.29294825e+00 -1.10349190e+00 3.81862259e+00 -1.23572731e+00]
 [ 1.06328213e+00 -6.86575174e+00 -3.61938500e+00 1.15000987e+00 -2.32747698e+00 1.32029047e+01 -3.19671059e+00 8.91836137e-02 4.70500660e+00 2.51928687e+00]
 [ 1.00363417e+01 -8.02793884e+00 9.38569680e-02 -2.78522468e+00 -2.84671545e+00 2.95867848e+00 4.23545599e+00 2.52083421e+00 -3.35666537e+00 1.45630157e+00]]

The GitHub repo you can't even be bothered to Star

ai_fast_handbook

标签:MNIST预测

本文共计3935个文字,预计阅读时间需要16分钟。

Mxnet如何实现mnist预测训练,模型导出为onnx,再导入预测?

需要做什么?+ 方便广泛 + 烟酒生 + 研究生、+ 人工智能工程师 + 算法工程师快速使用mxnet,特写此文章。默认用户已有基本的深度学习、数据集概念。+ 系统环境 + Python 3.7.4 + mxnet 1.+

需要做点什么

方便广大烟酒生研究生、人工智障炼丹师算法工程师快速使用mxnet,所以特写此文章,默认使用者已有基本的深度学习概念、数据集概念。

系统环境

python 3.7.4
mxnet 1.9.0
mxnet-cu112 1.9.0
onnx 1.9.0
onnxruntime-gpu 1.9.0

数据准备

MNIST数据集csv文件是一个42000x785的矩阵
42000表示有42000张图片
785中第一列是图片的类别(0,1,2,..,9),第二列到最后一列是图片数据向量 (28x28的图片张成784的向量), 数据集长这个样子:

Mxnet如何实现mnist预测训练,模型导出为onnx,再导入预测?

1 0 0 0 0 0 0 0 0 0 ..
0 0 0 0 0 0 0 0 0 0 ..
1 0 0 0 0 0 0 0 0 0 ..

1. 导入需要的包

import time import copy import onnx import logging import platform import mxnet as mx import numpy as np import pandas as pd import onnxruntime as ort from sklearn.metrics import accuracy_score logger = logging.getLogger() logger.setLevel(logging.DEBUG) # Mxnet Chcek if platform.system().lower() != 'windows': print(mx.runtime.feature_list()) print(mx.context.num_gpus()) a = mx.nd.ones((2, 3), mx.cpu()) b = a * 2 + 1 print(b)

运行输出

[✔ CUDA, ✔ CUDNN, ✔ NCCL, ✔ CUDA_RTC, ✖ TENSORRT, ✔ CPU_SSE, ✔ CPU_SSE2, ✔ CPU_SSE3, ✖ CPU_SSE4_1, ✖ CPU_SSE4_2, ✖ CPU_SSE4A, ✖ CPU_AVX, ✖ CPU_AVX2, ✔ OPENMP, ✖ SSE, ✖ F16C, ✖ JEMALLOC, ✔ BLAS_OPEN, ✖ BLAS_ATLAS, ✖ BLAS_MKL, ✖ BLAS_APPLE, ✔ LAPACK, ✔ MKLDNN, ✔ OPENCV, ✖ CAFFE, ✖ PROFILER, ✔ DIST_KVSTORE, ✖ CXX14, ✖ INT64_TENSOR_SIZE, ✔ SIGNAL_HANDLER, ✖ DEBUG, ✖ TVM_OP] 1 [[3. 3. 3.] [3. 3. 3.]] <NDArray 2x3 @cpu(0)> 2. 参数准备

N_EPOCH = 1 N_BATCH = 32 N_BATCH_NUM = 900 S_DATA_PATH = r"mnist_train.csv" S_MODEL_PATH = r"mxnet_cnn" S_SYM_PATH = './mxnet_cnn-symbol.json' S_PARAMS_PATH = './mxnet_cnn-0001.params' S_ONNX_MODEL_PATH = './mxnet_cnn.onnx' S_DEVICE, N_DEVICE_ID, S_DEVICE_FULL = "cuda", 0, "cuda:0" # S_DEVICE, N_DEVICE_ID, S_DEVICE_FULL = "cpu", 0, "cpu" CTX = mx.cpu() if S_DEVICE == "cpu" else mx.gpu(N_DEVICE_ID) B_IS_UNIX = True 3. 读取数据

df = pd.read_csv(S_DATA_PATH, header=None) print(df.shape) np_mat = np.array(df) print(np_mat.shape) X = np_mat[:, 1:] Y = np_mat[:, 0] X = X.astype(np.float32) / 255 X_train = X[:N_BATCH * N_BATCH_NUM] X_test = X[N_BATCH * N_BATCH_NUM:] Y_train = Y[:N_BATCH * N_BATCH_NUM] Y_test = Y[N_BATCH * N_BATCH_NUM:] X_train = X_train.reshape(X_train.shape[0], 1, 28, 28) X_test = X_test.reshape(X_test.shape[0], 1, 28, 28) print(X_train.shape) print(Y_train.shape) print(X_test.shape) print(Y_test.shape) train_iter = mx.io.NDArrayIter(X_train, Y_train, batch_size=N_BATCH) test_iter = mx.io.NDArrayIter(X_test, Y_test, batch_size=N_BATCH) test_iter_2 = copy.copy(test_iter)

运行输出

(37800, 785) (37800, 785) (28800, 1, 28, 28) (28800,) (9000, 1, 28, 28) (9000,) 4. 模型构建

net = mx.gluon.nn.HybridSequential() with net.name_scope(): net.add(mx.gluon.nn.Conv2D(channels=32, kernel_size=3, activation='relu')) # bx28x28 ==> net.add(mx.gluon.nn.MaxPool2D(pool_size=2, strides=2)) net.add(mx.gluon.nn.Flatten()) net.add(mx.gluon.nn.Dense(128, activation="relu")) net.add(mx.gluon.nn.Dense(10)) net.hybridize() print(net) net.collect_params().initialize(mx.init.Xavier(), ctx=CTX) softmax_cross_entropy = mx.gluon.loss.SoftmaxCrossEntropyLoss() trainer = mx.gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': .001})

运行输出

HybridSequential( (0): Conv2D(None -> 32, kernel_size=(3, 3), stride=(1, 1), Activation(relu)) (1): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW) (2): Flatten (3): Dense(None -> 128, Activation(relu)) (4): Dense(None -> 10, linear) ) 5. 模型训练

for epoch in range(N_EPOCH): for batch_num, itr in enumerate(train_iter): data = itr.data[0].as_in_context(CTX) label = itr.label[0].as_in_context(CTX) with mx.autograd.record(): output = net(data) # Run the forward pass loss = softmax_cross_entropy(output, label) # Compute the loss loss.backward() trainer.step(data.shape[0]) if batch_num % 50 == 0: # Print loss once in a while curr_loss = mx.nd.mean(loss) # .asscalar() pred = mx.nd.argmax(output, axis=1) np_pred, np_lable = pred.asnumpy(), label.asnumpy() f_acc = accuracy_score(np_lable, np_pred) print(f"Epoch: {epoch}; Batch {batch_num}; ACC {f_acc}") print(f"loss: {curr_loss}") print() # print("Epoch: %d; Batch %d; Loss %s; ACC %f" % # (epoch, batch_num, str(curr_loss), f_acc)) print()

运行输出

Epoch: 0; Batch 0; ACC 0.09375 loss: [2.2868602] <NDArray 1 @gpu(0)> Epoch: 0; Batch 50; ACC 0.875 loss: [0.512461] <NDArray 1 @gpu(0)> Epoch: 0; Batch 100; ACC 0.90625 loss: [0.43415746] <NDArray 1 @gpu(0)> Epoch: 0; Batch 150; ACC 0.84375 loss: [0.3854709] <NDArray 1 @gpu(0)> Epoch: 0; Batch 200; ACC 1.0 loss: [0.04192135] <NDArray 1 @gpu(0)> Epoch: 0; Batch 250; ACC 0.90625 loss: [0.21156572] <NDArray 1 @gpu(0)> Epoch: 0; Batch 300; ACC 0.9375 loss: [0.15938525] <NDArray 1 @gpu(0)> Epoch: 0; Batch 350; ACC 1.0 loss: [0.0379494] <NDArray 1 @gpu(0)> Epoch: 0; Batch 400; ACC 0.96875 loss: [0.17104594] <NDArray 1 @gpu(0)> Epoch: 0; Batch 450; ACC 0.96875 loss: [0.12192786] <NDArray 1 @gpu(0)> Epoch: 0; Batch 500; ACC 0.96875 loss: [0.09210478] <NDArray 1 @gpu(0)> Epoch: 0; Batch 550; ACC 0.9375 loss: [0.13728428] <NDArray 1 @gpu(0)> Epoch: 0; Batch 600; ACC 0.96875 loss: [0.0762211] <NDArray 1 @gpu(0)> Epoch: 0; Batch 650; ACC 0.96875 loss: [0.12162409] <NDArray 1 @gpu(0)> Epoch: 0; Batch 700; ACC 1.0 loss: [0.04334489] <NDArray 1 @gpu(0)> Epoch: 0; Batch 750; ACC 1.0 loss: [0.06458903] <NDArray 1 @gpu(0)> Epoch: 0; Batch 800; ACC 0.96875 loss: [0.07410634] <NDArray 1 @gpu(0)> Epoch: 0; Batch 850; ACC 0.96875 loss: [0.14233188] <NDArray 1 @gpu(0)> 6.模型预测

for batch_num, itr in enumerate(test_iter_2): data = itr.data[0].as_in_context(CTX) label = itr.label[0].as_in_context(CTX) output = net(data) # Run the forward pass loss = softmax_cross_entropy(output, label) # Compute the loss if batch_num % 50 == 0: # Print loss once in a while curr_loss = mx.nd.mean(loss) # .asscalar() pred = mx.nd.argmax(output, axis=1) np_pred, np_lable = pred.asnumpy(), label.asnumpy() f_acc = accuracy_score(np_lable, np_pred) print(f"Epoch: {epoch}; Batch {batch_num}; ACC {f_acc}") print(f"loss: {curr_loss}") print()

运行输出

Epoch: 0; Batch 0; ACC 0.96875 loss: [0.22968824] <NDArray 1 @gpu(0)> Epoch: 0; Batch 50; ACC 0.96875 loss: [0.05668993] <NDArray 1 @gpu(0)> Epoch: 0; Batch 100; ACC 0.96875 loss: [0.08171713] <NDArray 1 @gpu(0)> Epoch: 0; Batch 150; ACC 1.0 loss: [0.02264522] <NDArray 1 @gpu(0)> Epoch: 0; Batch 200; ACC 0.96875 loss: [0.080383] <NDArray 1 @gpu(0)> Epoch: 0; Batch 250; ACC 1.0 loss: [0.03774196] <NDArray 1 @gpu(0)> 7.模型保存

net.export(S_MODEL_PATH, epoch=N_EPOCH) # 保存模型结构和全部参数 8.模型加载和加载模型使用

print("load net and do test") load_net = mx.gluon.nn.SymbolBlock.imports(S_SYM_PATH, ['data'], S_PARAMS_PATH, ctx=CTX) # 加载模型 print("load ok") for batch_num, itr in enumerate(test_iter): # Test data = itr.data[0].as_in_context(CTX) label = itr.label[0].as_in_context(CTX) output = load_net(data) # Run the forward pass loss = softmax_cross_entropy(output, label) # Compute the loss if batch_num % 50 == 0: # Print loss once in a while curr_loss = mx.nd.mean(loss) # .asscalar() pred = mx.nd.argmax(output, axis=1) np_pred, np_lable = pred.asnumpy(), label.asnumpy() f_acc = accuracy_score(np_lable, np_pred) print(f"Epoch: {epoch}; Batch {batch_num}; ACC {f_acc}") print(f"loss: {curr_loss}") print() print("finish")

运行输出

load net and do test load ok Epoch: 0; Batch 0; ACC 0.96875 loss: [0.22968824] <NDArray 1 @gpu(0)> Epoch: 0; Batch 50; ACC 0.96875 loss: [0.05668993] <NDArray 1 @gpu(0)> Epoch: 0; Batch 100; ACC 0.96875 loss: [0.08171713] <NDArray 1 @gpu(0)> Epoch: 0; Batch 150; ACC 1.0 loss: [0.02264522] <NDArray 1 @gpu(0)> Epoch: 0; Batch 200; ACC 0.96875 loss: [0.080383] <NDArray 1 @gpu(0)> Epoch: 0; Batch 250; ACC 1.0 loss: [0.03774196] <NDArray 1 @gpu(0)> finish 9.导出ONNX

if platform.system().lower() != 'windows': mx.onnx.export_model(S_SYM_PATH, S_PARAMS_PATH, [(32, 1, 28, 28)], [np.float32], S_ONNX_MODEL_PATH, verbose=True, dynamic=True)

运行输出

INFO:root:Converting json and weight file to sym and params INFO:root:Converting idx: 0, op: null, name: data INFO:root:Converting idx: 1, op: null, name: hybridsequential0_conv0_weight INFO:root:Converting idx: 2, op: null, name: hybridsequential0_conv0_bias INFO:root:Converting idx: 3, op: Convolution, name: hybridsequential0_conv0_fwd INFO:root:Converting idx: 4, op: Activation, name: hybridsequential0_conv0_relu_fwd INFO:root:Converting idx: 5, op: Pooling, name: hybridsequential0_pool0_fwd INFO:root:Converting idx: 6, op: Flatten, name: hybridsequential0_flatten0_flatten0 INFO:root:Converting idx: 7, op: null, name: hybridsequential0_dense0_weight INFO:root:Converting idx: 8, op: null, name: hybridsequential0_dense0_bias INFO:root:Converting idx: 9, op: FullyConnected, name: hybridsequential0_dense0_fwd INFO:root:Converting idx: 10, op: Activation, name: hybridsequential0_dense0_relu_fwd INFO:root:Converting idx: 11, op: null, name: hybridsequential0_dense1_weight INFO:root:Converting idx: 12, op: null, name: hybridsequential0_dense1_bias INFO:root:Converting idx: 13, op: FullyConnected, name: hybridsequential0_dense1_fwd INFO:root:Output node is: hybridsequential0_dense1_fwd INFO:root:Input shape of the model [(32, 1, 28, 28)] INFO:root:Exported ONNX file ./mxnet_cnn.onnx saved to disk 10. 加载ONNX并运行

if platform.system().lower() != 'windows': model = onnx.load(S_ONNX_MODEL_PATH) print(onnx.checker.check_model(model)) # Check that the model is well formed # print(onnx.helper.printable_graph(model.graph)) # Print a human readable representation of the graph ls_input_name, ls_output_name = [input.name for input in model.graph.input], [output.name for output in model.graph.output] print("input name ", ls_input_name) print("output name ", ls_output_name) s_input_name = ls_input_name[0] x_input = X_train[:N_BATCH*2, :, :, :].astype(np.float32) ort_val = ort.OrtValue.ortvalue_from_numpy(x_input, S_DEVICE, N_DEVICE_ID) print("val device ", ort_val.device_name()) print("val shape ", ort_val.shape()) print("val data type ", ort_val.data_type()) print("is_tensor ", ort_val.is_tensor()) print("array_equal ", np.array_equal(ort_val.numpy(), x_input)) providers = 'CUDAExecutionProvider' if S_DEVICE == "cuda" else 'CPUExecutionProvider' print("providers ", providers) ort_session = ort.InferenceSession(S_ONNX_MODEL_PATH, providers=[providers]) # gpu运行 ort_session.set_providers([providers]) outputs = ort_session.run(None, {s_input_name: ort_val}) print("sess env ", ort_session.get_providers()) print(type(outputs)) print(outputs[0]) ''' For example ['CUDAExecutionProvider', 'CPUExecutionProvider'] means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider. '''

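`outputs[0]` holds the raw scores (logits) of the last dense layer, one row of 10 values per image, not class labels. A minimal sketch of turning them into predictions follows. The three logit rows are the first three rows of the run output, rounded to three decimals; the `labels` array is an illustrative assumption, not values read from `Y_train`.

```python
import numpy as np

# First three rows of outputs[0] from the run output, rounded to 3 decimals.
logits = np.array([
    [-2.693, 8.425, -0.334, -1.179, 0.382, -3.608, 0.358, -2.581, 1.552, -2.036],
    [10.267, -6.658, -0.205, -2.257, -6.319, 1.131, -0.384, 0.822, -1.212, 3.336],
    [-3.272, 10.005, 0.054, -1.449, -0.990, -2.100, -1.494, 0.653, 1.732, -1.256],
], dtype=np.float32)

# Assumed ground-truth labels for these three images (illustrative only).
labels = np.array([1, 0, 1])


def softmax(z):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


probs = softmax(logits)        # per-class probabilities, each row sums to 1
pred = logits.argmax(axis=1)   # predicted digit = index of the largest logit
accuracy = float((pred == labels).mean())

print("pred     ", pred)
print("accuracy ", accuracy)
```

On the full result, the same idea would be `outputs[0].argmax(axis=1)` compared against the matching slice of `Y_train`, for example with the `accuracy_score` function already imported in step 1.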
Run output

None
input name ['data', 'hybridsequential0_conv0_weight', 'hybridsequential0_conv0_bias', 'hybridsequential0_dense0_weight', 'hybridsequential0_dense0_bias', 'hybridsequential0_dense1_weight', 'hybridsequential0_dense1_bias']
output name ['hybridsequential0_dense1_fwd']
val device cuda
val shape [64, 1, 28, 28]
val data type tensor(float)
is_tensor True
array_equal True
providers CUDAExecutionProvider
sess env ['CUDAExecutionProvider', 'CPUExecutionProvider']
<class 'list'>
[[-2.69336128e+00 8.42524242e+00 -3.34120363e-01 -1.17912292e+00 3.82278800e-01 -3.60794234e+00 3.58125120e-01 -2.58064723e+00 1.55215383e+00 -2.03553891e+00]
 [ 1.02665892e+01 -6.65782404e+00 -2.04501271e-01 -2.25653172e+00 -6.31941366e+00 1.13084137e+00 -3.83885235e-01 8.22283030e-01 -1.21192622e+00 3.33601260e+00]
 [-3.27186418e+00 1.00050325e+01 5.39114550e-02 -1.44938648e+00 -9.89762247e-01 -2.09957671e+00 -1.49389958e+00 6.52510405e-01 1.73153889e+00 -1.25597775e+00]
 [ 5.72116375e-01 -3.36192799e+00 -6.68362260e-01 -2.81247520e+00 8.36382389e+00 -3.67477946e-02 2.23792076e+00 -2.91093756e-02 -4.56922323e-01 -6.77382052e-01]
 [ 1.18602552e+01 -5.09683752e+00 4.54203248e-01 -2.55723000e+00 -8.68753910e+00 6.96948707e-01 -1.50591761e-01 -3.62227589e-01 9.83437955e-01 7.46711075e-01]
 [ 7.33289337e+00 -6.65414715e+00 1.57180679e+00 -2.62657452e+00 4.11511570e-01 -1.35336161e+00 -1.40558392e-01 3.81030589e-01 1.73799121e+00 8.02671254e-01]
 [-3.02898431e+00 1.26861107e+00 -2.04946566e+00 -2.52499342e-01 -2.73597687e-01 -3.01714039e+00 -7.10914516e+00 1.10452967e+01 -5.82177579e-01 1.86712158e+00]
 [-7.78098392e+00 -6.01984358e+00 1.23355007e+00 1.18682652e+01 -9.83472538e+00 8.27242088e+00 -1.02135544e+01 3.95661980e-01 6.63226461e+00 3.33681512e+00]
 [-2.72245955e+00 -6.74849796e+00 -6.24665642e+00 3.11165476e+00 -4.71174330e-01 1.22390661e+01 -1.23519528e+00 -1.24356663e+00 1.26693976e+00 5.81862879e+00]
 [-5.65229607e+00 -1.25138938e+00 3.68380380e+00 1.24947300e+01 -8.21508980e+00 1.61641145e+00 -8.01925087e+00 8.37018967e-01 -2.64613247e+00 7.92313635e-01]
 [-3.73405719e+00 -3.41621947e+00 -7.94842839e-01 4.55352879e+00 -2.28238964e+00 1.88887548e+00 -5.84129477e+00 6.03430390e-01 1.05920439e+01 2.25430655e+00]
 [-5.44103146e+00 -5.48421431e+00 -3.62234282e+00 1.20194650e+00 3.48899674e+00 1.50794566e+00 -6.30612850e+00 4.01568127e+00 1.61318648e+00 9.87832165e+00]
 [-3.34073186e+00 8.10987663e+00 -6.43497527e-01 -1.64372277e+00 -4.42907363e-01 -1.46176386e+00 -8.56327295e-01 5.20323329e-02 1.73289025e+00 -8.17061365e-01]
 [-6.88457203e+00 1.38391244e+00 1.33096969e+00 1.28132534e+01 -6.20939922e+00 1.48244214e+00 -6.59804583e+00 -1.38118923e+00 4.26289368e+00 -1.22962976e+00]
 [-6.09051991e+00 -3.15275192e+00 1.79273260e+00 9.92699528e+00 -5.97349882e+00 3.68225765e+00 -6.47421646e+00 -1.99264419e+00 2.15714622e+00 2.32836318e+00]
 [-3.25946307e+00 8.14360428e+00 -1.00535810e+00 -2.37552500e+00 2.38139248e+00 -2.92597318e+00 -1.54173911e+00 2.25682306e+00 -2.83430189e-01 -1.33554244e+00]
 [-2.99147058e+00 3.86941671e+00 8.82810593e+00 2.20121431e+00 -8.40485859e+00 -8.66728902e-01 -5.97998762e+00 -5.21699572e+00 5.80638123e+00 -2.57314467e+00]
 [ 8.64277363e+00 -4.99241495e+00 2.84688592e+00 -4.15350378e-01 -1.87728360e-01 -2.40291572e+00 4.42544132e-01 -4.54446167e-01 -1.88113344e+00 -1.23334014e+00]
 [-2.00169897e+00 -2.65497804e+00 1.18750989e+00 9.70900059e-01 -4.53840446e+00 -2.65584946e+00 -8.23472023e+00 9.93836498e+00 -5.57100773e-01 3.42955470e+00]
 [-3.57249069e+00 -5.03176594e+00 -1.79369414e+00 -5.03321826e-01 -1.97100627e+00 9.01608944e+00 6.62497377e+00 -5.48222637e+00 6.09256268e+00 -4.71334040e-01]
 [-5.27715540e+00 -7.84428477e-01 -6.26944721e-01 3.87298250e+00 -1.88836837e+00 1.15252662e+00 -2.98473048e+00 -3.10233998e+00 9.71112537e+00 3.10839200e+00]
 [-9.50223565e-01 -6.47654009e+00 2.26750326e+00 1.95419586e+00 1.68217969e+00 1.66003108e+00 9.82697105e+00 -9.94868219e-01 -2.03924966e+00 -1.88321277e-01]
 [-3.11575246e+00 3.43664408e+00 1.19877796e+01 4.36916590e+00 -1.17812777e+01 -1.69431508e+00 -5.82668829e+00 -5.09948444e+00 4.15738583e+00 -4.30461359e+00]
 [ 9.72177792e+00 -5.31352401e-01 -1.21784186e+00 -1.07392669e+00 -7.11223555e+00 1.67838800e+00 1.01826215e+00 -8.88240516e-01 6.95959151e-01 2.38748863e-01]
 [-2.06619406e+00 1.86608231e+00 1.12100420e+01 4.22539425e+00 -1.21493711e+01 -4.57662535e+00 -6.88935089e+00 -9.81215835e-01 3.87611055e+00 -3.28470826e+00]
 [-6.73031902e+00 -2.54390073e+00 -1.10151446e+00 1.51524162e+01 -1.10052080e+01 6.60323954e+00 -7.94388771e+00 3.31939721e+00 -1.40840662e+00 2.65730071e+00]
 [-1.96954179e+00 -1.13817227e+00 9.40351069e-01 -1.75684047e+00 3.60373807e+00 2.01377797e+00 1.00558109e+01 -1.10547984e+00 5.17374456e-01 -3.94047165e+00]
 [-5.81787634e+00 -1.20211565e+00 -3.53216052e+00 1.17569458e+00 4.21314144e+00 -2.53644252e+00 -7.64466667e+00 4.19782829e+00 4.28840429e-01 1.04579344e+01]
 [-4.20310974e+00 -3.19272375e+00 -4.62792778e+00 2.71683741e+00 4.43899345e+00 3.31357956e-01 -6.24839544e+00 3.80388188e+00 -1.22620119e-02 9.65024757e+00]
 [-8.26945066e-01 -5.25947523e+00 3.72772887e-02 2.30585241e+00 -4.95726252e+00 -1.19987357e+00 -1.20395079e+01 1.53253164e+01 -2.10372299e-01 1.89387524e+00]
 [-5.09596729e+00 -7.76027665e-02 -9.53466833e-01 2.89041376e+00 -1.50858855e+00 2.27854323e+00 -1.95591903e+00 -3.15785193e+00 1.00103540e+01 1.08987451e+00]
 [-4.01680946e-01 -4.62062168e+00 3.90530303e-02 -1.66790807e+00 5.43311167e+00 -1.78802896e+00 -2.88405848e+00 2.93439984e+00 -2.16558409e+00 8.71198368e+00]
 [-1.29969406e+00 -3.92871022e+00 -3.82151055e+00 -2.93831253e+00 1.03674269e+01 8.88044477e-01 7.88922787e-01 3.86107159e+00 1.60807288e+00 -3.76913965e-01]
 [-2.25020099e+00 -8.17249107e+00 -1.82360613e+00 -5.90175152e-01 4.72407389e+00 9.39436078e-01 -3.85674310e+00 3.95303702e+00 1.83473241e+00 9.13874054e+00]
 [-3.26617742e+00 -2.91517663e+00 8.37770653e+00 1.61820054e-01 -2.98638320e+00 -2.47211266e+00 5.08574843e-01 4.65608168e+00 2.66001201e+00 -4.67262363e+00]
 [-2.29874635e+00 7.77097034e+00 1.11359918e+00 -2.06103897e+00 -7.61267126e-01 1.00877440e+00 1.47708499e+00 -1.20483887e+00 1.99922264e+00 -3.81118345e+00]
 [-6.87821198e+00 -9.18823600e-01 2.16773844e+00 1.07671242e+01 -7.48823595e+00 2.90310860e+00 -1.02075748e+01 -3.83400464e+00 4.76818371e+00 4.06564474e+00]
 [-2.06487226e+00 8.76828384e+00 1.10449910e+00 -2.29669046e+00 -1.15668392e+00 -2.50351834e+00 -1.69508122e-02 -1.05916834e+00 1.91057479e+00 -2.64592767e+00]
 [-2.24318981e+00 9.02024174e+00 1.38990092e+00 -2.72154903e+00 1.46101296e-01 -4.43454313e+00 -8.21092844e-01 -2.40808502e-01 3.36577922e-01 -2.63193059e+00]
 [-4.35961342e+00 -7.74704576e-01 -2.74345660e+00 -3.27951574e+00 1.50971518e+01 -2.80669570e+00 -1.28740633e+00 3.94157290e+00 4.20372874e-01 8.37333024e-01]
 [-2.80749345e+00 -3.33036280e+00 -1.00865018e+00 4.57633829e+00 -5.03952360e+00 2.93345642e+00 -8.54609489e+00 2.26549125e+00 4.73208952e+00 5.93849993e+00]
 [-3.31042576e+00 9.97719002e+00 4.38573778e-01 -1.35296178e+00 -1.21057940e+00 -2.46178842e+00 -8.34564090e-01 2.19030753e-01 2.03147411e+00 -1.80211437e+00]
 [-2.14534068e+00 -2.93023801e+00 4.41405416e-01 -2.29865336e+00 1.47422533e+01 -1.86358702e+00 4.61042017e-01 -6.20108247e-01 -1.36792552e+00 2.14018896e-01]
 [-2.64241481e+00 -2.28332114e+00 2.01109338e+00 9.67352509e-01 6.09287119e+00 -2.35626236e-01 -3.02941656e+00 4.32772923e+00 -4.63955021e+00 3.73136783e+00]
 [-4.55847168e+00 1.04014337e+00 9.12987328e+00 2.06433630e+00 -1.67355919e+00 -1.49593079e+00 4.09124941e-01 2.41894865e+00 -6.86871633e-02 -4.42179346e+00]
 [-6.10608578e-01 -2.73860097e+00 -1.09864855e+00 -5.68899512e-01 2.45831108e+00 5.50326490e+00 1.22601585e+01 -5.23877192e+00 2.11066246e+00 -2.98584485e+00]
 [-5.96745872e+00 -1.91458237e+00 -3.10774088e-01 1.00216856e+01 -3.81997776e+00 7.14399862e+00 -4.61386251e+00 -5.18248987e+00 4.25162363e+00 1.18878789e-01]
 [-4.24126434e+00 -9.63249326e-01 1.06391013e+00 4.45316315e+00 -4.47125340e+00 -1.21906054e+00 -9.75789547e+00 1.10335569e+01 -1.17632782e+00 8.78942788e-01]
 [ 3.76867175e-01 -4.75102758e+00 -2.59345794e+00 -3.96257102e-01 -7.50329159e-03 1.81642962e+00 -6.01041269e+00 9.97849655e+00 -2.57468176e+00 5.00644207e+00]
 [-2.21995306e+00 -4.23465443e+00 -2.19536662e+00 -3.71420813e+00 1.49460211e+01 -2.73240638e+00 3.03538777e-02 4.29108334e+00 1.21963143e+00 6.85284913e-01]
 [-1.93794692e+00 -3.39166284e+00 -3.41372967e-01 -2.16144085e-01 1.32588074e-01 -3.83050554e-02 -7.32452822e+00 9.68561840e+00 -1.16044319e+00 3.63913298e+00]
 [-3.31972694e+00 -4.84879112e+00 -3.41381001e+00 1.93332338e+00 -5.16150045e+00 1.05730495e+01 1.52961123e+00 -4.64916992e+00 4.11477995e+00 4.80105543e+00]
 [-3.22445488e+00 1.01509037e+01 6.03635088e-02 -1.48001885e+00 -3.28380197e-01 -2.36782789e+00 -8.66441727e-01 6.68077886e-01 1.45576179e+00 -1.95271623e+00]
 [-5.64618587e+00 -3.71156931e+00 -2.31174397e+00 1.76912701e+00 2.95111752e+00 2.09562635e+00 -5.34609461e+00 -2.59399921e-01 1.10373080e+00 1.10444403e+01]
 [ 1.32743053e+01 -5.11389780e+00 1.10446036e+00 -5.01595545e+00 -4.87907743e+00 4.46035750e-02 4.57046556e+00 -1.76434004e+00 -5.40824793e-02 -1.01547205e+00]
 [-4.27589798e+00 1.26832044e+00 6.49948978e+00 -8.06193352e-02 1.34645328e-01 6.92090929e-01 -8.92272711e-01 2.15252757e+00 -9.95365143e-01 -2.46636438e+00]
 [-5.08349514e+00 -1.79646879e-01 9.57399654e+00 3.35643005e+00 -3.42183971e+00 -2.33653522e+00 -1.98645079e+00 1.50538552e+00 -1.43313253e+00 -9.75638926e-01]
 [-8.48450565e+00 -4.32870531e+00 -1.54253757e+00 1.61029205e+01 -8.41084957e+00 5.45092726e+00 -1.00705996e+01 2.31331086e+00 1.05400957e-01 5.19563723e+00]
 [-4.89484310e+00 -4.33120441e+00 -3.71932483e+00 -1.18670315e-01 4.70187807e+00 2.67010808e-01 -5.52650118e+00 5.62736416e+00 5.05499423e-01 8.69004250e+00]
 [-1.94048929e+00 7.98250961e+00 -3.75457823e-01 -2.09433126e+00 -4.63896915e-02 -2.64650702e+00 -5.73348761e-01 -1.34597674e-01 1.26143038e+00 -1.32178509e+00]
 [-6.36190224e+00 5.44552183e+00 1.89667892e+00 3.91665459e+00 -1.99860716e+00 2.76620483e+00 3.57649279e+00 -3.53898597e+00 -1.21411644e-01 -2.60307193e+00]
 [-3.60333633e+00 4.92228556e+00 2.87770915e+00 1.31902504e+00 -4.28756446e-01 -3.29862523e+00 -2.29294825e+00 -1.10349190e+00 3.81862259e+00 -1.23572731e+00]
 [ 1.06328213e+00 -6.86575174e+00 -3.61938500e+00 1.15000987e+00 -2.32747698e+00 1.32029047e+01 -3.19671059e+00 8.91836137e-02 4.70500660e+00 2.51928687e+00]
 [ 1.00363417e+01 -8.02793884e+00 9.38569680e-02 -2.78522468e+00 -2.84671545e+00 2.95867848e+00 4.23545599e+00 2.52083421e+00 -3.35666537e+00 1.45630157e+00]]

The GitHub repo you can't even be bothered to Star

ai_fast_handbook
