前几天参照网友教程,我是Anaconda Prompt终端命令下进行的,在生成模型文件大小有问题,别人都是MB级别,我的是KB,估计是有问题,今晚又折腾了下,参照网友wangerxian帖子https://bbs.eeworld.com.cn/thread-1278516-1-1.html操作,先安装了PyCharm,然后再执行py。
执行的代码直接copy网友knv的,用GPU参与训练的,GPU占有率大概能到45%左右,CPU几乎跑满了。
生成文件如下:
文件大小应该是对了,跟knv生成的模型大小基本一致。
接下来参照其他网友的执行看看:
1、新建个工程
环境配置参照上面图片,Path to conda那里需要选择你Anaconda安装目录,别搞错了。Environment选择你生成的虚拟python环境,要提前在Prompt里面安装好相关包文件。
2、新建Python文件
①位置工程文件夹上右键,②新建,③选择Python文件
输入个文件名
把代码拷贝粘贴进来,然后点开左下角终端按钮
执行python文件
文件上右键,选择运行
执行结果:
pt文件生成了,但是转onnx时候报错,也不知道啥情况,有可能跟软件版本有关。
再来试试其他GPU训练的代码:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchsummary import summary
import time
# 创建神经网络
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.output_layer = nn.Linear(32*7*7, 10)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = x.reshape(x.size(0), -1)
output = self.output_layer(x)
return output
# 超参数
EPOCH = 2
BATCH_SIZE = 100
LR = 0.001
DOWNLOAD = True # 若已经下载mnist数据集则设为False
# 下载mnist数据
train_data = datasets.MNIST(
root='./data', # 保存路径
train=True, # True表示训练集,False表示测试集
transform=transforms.ToTensor(), # 将0~255压缩为0~1
download=DOWNLOAD
)
# 旧的写法
print(train_data.train_data.size())
print(train_data.train_labels.size())
# 新的写法
print(train_data.data.size())
print(train_data.targets.size())
# 打印部分数据集的图片
for i in range(2):
print(train_data.targets[i].item())
plt.imshow(train_data.data[i].numpy(), cmap='gray')
plt.show()
# DataLoader
train_loader = Data.DataLoader(
dataset=train_data,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2
)
# 如果train_data下载好后,test_data也就下载好了
test_data = datasets.MNIST(
root='./data',
train=False
)
print(test_data.data.size())
print(test_data.targets.size())
# 新建网络
cnn = CNN()
# 将神经网络移到GPU上
cnn.cuda()
print(cnn)
# 查看网络的结构
model = CNN()
if torch.cuda.is_available():
model.cuda()
summary(model, input_size=(1,28,28))
# 优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
# 损失函数
loss_func = nn.CrossEntropyLoss()
# 为了节约时间,只使用测试集的前2000个数据
test_x = Variable(
torch.unsqueeze(test_data.data, dim=1),
volatile=True
).type(torch.FloatTensor)[:2000]/255 # 将将0~255压缩为0~1
test_y = test_data.targets[:2000]
# # 使用所有的测试集
# test_x = Variable(
# torch.unsqueeze(test_data.test_data, dim=1),
# volatile=True
# ).type(torch.FloatTensor)/255 # 将将0~255压缩为0~1
# test_y = test_data.test_labels
# 将测试数据移到GPU上
test_x = test_x.cuda()
test_y = test_y.cuda()
# 开始计时
start = time.time()
# 训练神经网络
for epoch in range(EPOCH):
for step, (batch_x, batch_y) in enumerate(train_loader):
# 将训练数据移到GPU上
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()
output = cnn(batch_x)
loss = loss_func(output, batch_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 每隔50步输出一次信息
if step%50 == 0:
test_output = cnn(test_x)
# 将预测结果移到GPU上
predict_y = torch.max(test_output, 1)[1].cuda().data.squeeze()
accuracy = (predict_y == test_y).sum().item() / test_y.size(0)
print('Epoch', epoch, '|', 'Step', step, '|', 'Loss', loss.data.item(), '|', 'Test Accuracy', accuracy)
# 结束计时
end = time.time()
# 训练耗时
print('Time cost:', end - start, 's')
# 预测
test_output = cnn(test_x[:100])
# 为了将CUDA tensor转化为numpy,需要将数据移回CPU上
# 否则会报错:TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
predict_y = torch.max(test_output, 1)[1].cpu().data.numpy().squeeze()
real_y = test_y[:100].cpu().numpy()
print(predict_y)
print(real_y)
# 打印预测和实际结果
for i in range(10):
print('Predict', predict_y[i])
print('Real', real_y[i])
plt.imshow(test_data.data[i].numpy(), cmap='gray')
plt.show()
执行后报错,
这里可以看到是没有torchsummary包,通过conda尝试,发现找不到这个包,无奈只能用pip来安装,
pip install torchsummary
再次执行,注意,如果你已经执行下载过mnist数据集,则39行为false,执行后这里会打开图片样本。
关闭图片窗口后会报错,搞不定了,继续换其他代码试试。
这里采用luyism的代码试试
https://bbs.eeworld.com.cn/thread-1278192-1-1.html
这里提示报错,原因是没安装onnxruntime包,在conda下安装下再次尝试可以生成模型文件,
这里尝试下qiao--网友的代码,
# 导包
import torch
import torch.nn as nn # 神经网络
import torch.optim as optim # 定义优化器
from torchvision import datasets, transforms # 数据集 transforms完成对数据的处理
# 定义超参数
input_size = 28 * 28 # 输入大小
hidden_size = 512 # 隐藏层大小
num_classes = 10 # 输出大小(类别数)
batch_size = 100 # 批大小
learning_rate = 0.001 # 学习率
num_epochs = 10 # 训练轮数
# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='../data/mnist', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='../data/mnist', train=False, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) # 一批数据为100个
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
# 定义 MLP 网络
class MLP(nn.Module):
# 初始化方法
# input_size 输入数据的维度
# hidden_size 隐藏层的大小
# num_classes 输出分类的数量
def __init__(self, input_size, hidden_size, num_classes):
# 调用父类的初始化方法
super(MLP, self).__init__()
# 定义第1个全连接层
self.fc1 = nn.Linear(input_size, hidden_size)
# 定义ReLu激活函数
self.relu = nn.ReLU()
# 定义第2个全连接层
self.fc2 = nn.Linear(hidden_size, hidden_size)
# 定义第3个全连接层
self.fc3 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
# 将输入张量展平为向量
x = x.view(x.size(0), -1)
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
out = self.relu(out)
out = self.fc3(out)
return out
# 实例化 MLP 网络
model = MLP(input_size, hidden_size, num_classes)
# 现在我们已经定义了 MLP 网络并加载了 MNIST 数据集,接下来使用 PyTorch 的自动求导功能和优化器进行训练。首先,定义损失函数和优化器;然后迭代训练数据并使用优化器更新网络参数。
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
# CrossEntropyLoss = Softmax + log + nllloss
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# optimizer = optim.SGD(model.parameters(),0.2)
# 训练网络
# 外层for循环控制训练的次数
# 内层for循环控制从DataLoader中循环读取数据
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.reshape(-1, 28 * 28) # 将images转换成向量
outputs = model(images) # 将数据送到网络中
loss = criterion(outputs, labels) # 计算损失
optimizer.zero_grad() # 首先将梯度清零
loss.backward() # 反向传播
optimizer.step() # 更新参数
if (i + 1) % 100 == 0:
print(f'Epoch [{epoch + 1}/{num_epochs}],Step[{i + 1}/{len(train_loader)}],Loss:{loss.item():.4f}')
# 最后,我们可以在测试数据上评估模型的准确率:
# 测试网络
with torch.no_grad():
correct = 0
total = 0
# 从test_loader中循环读取测试数据
for images, labels in test_loader:
# 将images转换成向量
images = images.reshape(-1, 28 * 28)
# 将数据传送到网络
outputs = model(images)
# 取出最大值对应的索引 即预测值
_, predicted = torch.max(outputs.data, 1) # 返回一个元组:第一个为最大值,第二个是最大值的下标
# 累加labels数量 labels为形状为(batch_size,1)的矩阵,取size(0)就是取出batch_size的大小(该批的大小)
total += labels.size(0)
# 预测值与labels值对比 获取预测正确的数量
correct += (predicted == labels).sum().item()
# 打印最终的准确率
print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')
# 保存模型
# 保存模型的状态字典
torch.save(model.state_dict(), 'mnist.pth')
model.load_state_dict(torch.load('mnist.pth'))
# 将模型设置为评估模式(如果需要)
model.eval()
#导出为onnx模型
input = torch.randn(1, 28, 28)
torch.onnx.export(model, input, "mnist.onnx", verbose=True)
执行结果如下;
我发现大家生成的模型文件差别挺大的,这里模型准确率可以到98%了,10轮训练。
补充下终极任务二今天研究情况。
首先非常感谢乔帮主群里的指点,最后发现我无法驱动成功sd卡的原因有三个:
sdcard类有问题,之前是从micropython下载的(也是坛友推荐的),一直没怀疑它,我用了乔帮主分享的立马就能初始化成功;
sd卡兼容性有问题,之前用的2G小容量的,在arduino下访问没问题,但是用micropython就不行了,然后换了张新的32G的就没事;
手头没杜邦线,用的4Pin端子连接线,接头上了点焊锡后查到排母里,貌似有点接触不良;
这里我把乔帮主分享的文件传上来,也给大家做参考。(大家根据自己接线方式修改Pin定义)
一、SD卡驱动
软件代码
1、测试SD卡
# Filename: tfcard_test.py
import uos # os/uos
import machine
import sdcard
from machine import SPI, Pin
spi = SPI(1, sck=Pin(10), mosi=Pin(11), miso=Pin(12))
cs = Pin(13)
sd = sdcard.SDCard(spi,cs)
# 挂载文件到sd
uos.mount(sd,"/sd")
# 列出MicroSD/TF卡中的目录文件
print(uos.listdir('/sd'))
# 写文件测试
f = open('/sd/test.txt','w',encoding='utf-8')
f.write('MicroSD/TF存储卡访问测试!')
f.close()
# 读文件测试
f = open('/sd/test.txt','r')
print(f.read())
f.close()
2、运行结果:
二、FTP服务器搭建
1、软件代码
端口及参数定义
import gc
import uos
import time
import socket
import network
from time import localtime
from machine import Pin, SPI
from micropython import const
/'初始化引脚'/
_LED_PIN = const(25) # LED 指示灯引脚
_SPI_SPEED = const(2_000_000) # SPI 速率
_MOSI_PIN = const(19) # SPI MOSI 引脚
_MISO_PIN = const(16) # SPI MISO 引脚
_SCK_PIN = const(18) # SPI SCK 引脚
_CS_PIN = const(17) # SPI CS 引脚
_RST_PIN = const(20) # SPI RESET 引脚
FTP_ROOT_PATH = const("/") # FTP 根目录
month_name = ["","Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec",
]
# SPI 定义
spi = SPI(0, _SPI_SPEED, mosi=Pin(_MOSI_PIN), miso=Pin(_MISO_PIN), sck=Pin(_SCK_PIN))
nic = None
W5500初始化
""" W5500 初始化 """
def w5x00_init():
global nic
# 网口初始化
nic = network.WIZNET5K(spi, Pin(_CS_PIN), Pin(_RST_PIN)) # spi,cs,reset pin
nic.active(True)
# 配置网络
# If you use the Dynamic IP(DHCP), you must use the "nic.ifconfig('dhcp')".
nic.ifconfig("dhcp")
# If you use the Static IP, you must use the "nic.ifconfig("IP","subnet","Gateway","DNS")".
# nic.ifconfig(('192.168.0.106','255.255.255.0','192.168.0.1','114.114.114.114'))
while not nic.isconnected():
time.sleep(1)
print(nic.regs())
print("IP地址: %s" % nic.ifconfig()[0])
print("子网掩码: %s" % nic.ifconfig()[1])
print("网关: %s" % nic.ifconfig()[2])
print("DNS: %s" % nic.ifconfig()[3])
文件列表请求响应函数
'/ 响应文件列表请求 /'
def send_list_data(path, dataclient, full):
try: # whether path is a directory name
for fname in uos.listdir(path):
dataclient.sendall(make_description(path, fname, full))
except: # path may be a file name or pattern
pattern = path.split("/")[-1]
path = path[: -(len(pattern) + 1)]
if path == "":
path = "/"
for fname in uos.listdir(path):
if fncmp(fname, pattern) == True:
dataclient.sendall(make_description(path, fname, full))
目录详情
""" 列出目录详情 """
def make_description(path, fname, full):
if full:
stat = uos.stat(get_absolute_path(path, fname))
file_permissions = (
"drwxr-xr-x" if (stat[0] & 0o170000 == 0o040000) else "-rw-r--r--"
)
file_size = stat[6]
tm = localtime(stat[7])
if tm[0] != localtime()[0]:
description = "{} 1 owner group {:>10} {} {:2} {:>5} {}\r\n".format(
file_permissions, file_size, month_name[tm[1]], tm[2], tm[0], fname
)
else:
description = (
"{} 1 owner group {:>10} {} {:2} {:02}:{:02} {}\r\n".format(
file_permissions,
file_size,
month_name[tm[1]],
tm[2],
tm[3],
tm[4],
fname,
)
)
else:
description = fname + "\r\n"
return description
文件发送函数
""" 发送文件数据 """
def send_file_data(path, dataclient):
try:
with open(path, "rb") as file:
chunk = file.read(512)
print("chunk 0: ", len(chunk))
while len(chunk) > 0:
print("chunk: ", len(chunk))
dataclient.sendall(chunk)
chunk = file.read(512)
except Exception as err:
print("error: ", err.args, err.value, err.errno)
文件数据保存
""" 保存文件上传数据 """
def save_file_data(path, dataclient, mode):
with open(path, mode) as file:
chunk = dataclient.read(512)
while len(chunk) > 0:
file.write(chunk)
chunk = dataclient.read(512)
文件路径获取函数
""" 获取文件绝对路径 """
def get_absolute_path(cwd, payload):
# Just a few special cases "..", "." and ""
# If payload start's with /, set cwd to /
# and consider the remainder a relative path
if payload.startswith("/"):
cwd = "/"
for token in payload.split("/"):
if token == "..":
if cwd != "/":
cwd = "/".join(cwd.split("/")[:-1])
if cwd == "":
cwd = "/"
elif token != "." and token != "":
if cwd == "/":
cwd += token
else:
cwd = cwd + "/" + token
return cwd
文件名比较函数
""" 文件名比较 """
def fncmp(fname, pattern):
pi = 0
si = 0
while pi < len(pattern) and si < len(fname):
if (fname[si] == pattern[pi]) or (pattern[pi] == "?"):
si += 1
pi += 1
else:
if pattern[pi] == "*": # recurse
if (pi + 1) == len(pattern):
return True
while si < len(fname):
if fncmp(fname[si:], pattern[pi + 1 :]) == True:
return True
else:
si += 1
return False
else:
return False
if pi == len(pattern.rstrip("*")) and si == len(fname):
return True
else:
return False
FTP服务启动
""" 启动FTP服务 """
def ftpserver():
DATA_PORT = 13333
ftpsocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
datasocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ftpsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
datasocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ftpsocket.bind(socket.getaddrinfo("0.0.0.0", 21)[0][4])
datasocket.bind(socket.getaddrinfo("0.0.0.0", DATA_PORT)[0][4])
ftpsocket.listen(1)
datasocket.listen(1)
datasocket.settimeout(10)
print("FTP服务启动成功!监听端口:21")
msg_250_OK = "250 OK\r\n"
msg_550_fail = "550 Failed\r\n"
try:
dataclient = None
fromname = None
while True:
cl, remote_addr = ftpsocket.accept()
cl.settimeout(300)
cwd = FTP_ROOT_PATH
try:
print("新的FTP连接来自: %s:%s" % (remote_addr[0], remote_addr[1]))
cl.sendall("220 Welcome! This is the W5500_EVB_PICO!\r\n")
while True:
gc.collect()
data = cl.readline().decode("utf-8").rstrip("\r\n")
if len(data) <= 0:
print("Client disappeared")
break
command = data.split(" ")[0].upper()
payload = data[len(command) :].lstrip()
path = get_absolute_path(cwd, payload)
print("命令={}, 参数={}, 路径={}".format(command, payload, path))
if command == "USER":
cl.sendall("230 Logged in.\r\n")
elif command == "SYST":
cl.sendall("215 UNIX Type: L8\r\n")
elif command == "NOOP":
cl.sendall("200 OK\r\n")
elif command == "FEAT":
cl.sendall("211 no-features\r\n")
elif command == "PWD":
cl.sendall('257 "{}"\r\n'.format(cwd))
elif command == "CWD":
try:
files = uos.listdir(path)
cwd = path
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
elif command == "CDUP":
cwd = get_absolute_path(cwd, "..")
cl.sendall(msg_250_OK)
elif command == "TYPE":
# probably should switch between binary and not
cl.sendall("200 Transfer mode set\r\n")
elif command == "SIZE":
try:
size = uos.stat(path)[6]
cl.sendall("213 {}\r\n".format(size))
except:
cl.sendall(msg_550_fail)
elif command == "QUIT":
cl.sendall("221 Bye.\r\n")
break
elif command == "PASV":
addr = nic.ifconfig()[0]
cl.sendall(
"227 Entering Passive Mode ({},{},{}).\r\n".format(
addr.replace(".", ","), DATA_PORT >> 8, DATA_PORT % 256
)
)
dataclient, data_addr = datasocket.accept()
print("新的FTP数据连接来自: %s:%s" % (data_addr[0], data_addr[1]))
elif command == "LIST" or command == "NLST":
if not payload.startswith("-"):
place = path
else:
place = cwd
try:
send_list_data(
place, dataclient, command == "LIST" or payload == "-l"
)
cl.sendall("150 Here comes the directory listing.\r\n")
cl.sendall("226 Listed.\r\n")
except:
cl.sendall(msg_550_fail)
if dataclient is not None:
dataclient.close()
dataclient = None
elif command == "RETR":
try:
send_file_data(path, dataclient)
cl.sendall("150 Opening data connection.\r\n")
cl.sendall("226 Transfer complete.\r\n")
except:
cl.sendall(msg_550_fail)
if dataclient is not None:
dataclient.close()
dataclient = None
elif command == "STOR":
try:
cl.sendall("150 Ok to send data.\r\n")
save_file_data(path, dataclient, "wb")
cl.sendall("226 Transfer complete.\r\n")
except:
cl.sendall(msg_550_fail)
if dataclient is not None:
dataclient.close()
dataclient = None
elif command == "APPE":
try:
cl.sendall("150 Ok to send data.\r\n")
save_file_data(path, dataclient, "a")
cl.sendall("226 Transfer complete.\r\n")
except:
cl.sendall(msg_550_fail)
if dataclient is not None:
dataclient.close()
dataclient = None
elif command == "DELE":
try:
uos.remove(path)
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
elif command == "RMD":
try:
uos.rmdir(path)
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
elif command == "MKD":
try:
uos.mkdir(path)
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
elif command == "RNFR":
fromname = path
cl.sendall("350 Rename from\r\n")
elif command == "RNTO":
if fromname is not None:
try:
uos.rename(fromname, path)
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
else:
cl.sendall(msg_550_fail)
fromname = None
else:
cl.sendall("502 Unsupported command.\r\n")
# print("Unsupported command {} with payload {}".format(command, payload))
except Exception as err:
print(err)
finally:
cl.close()
cl = None
finally:
datasocket.close()
ftpsocket.close()
if dataclient is not None:
dataclient.close()
if __name__ == "__main__":
print("run in main")
w5x00_init() # 初始化网络
ftpserver() # 运行 FTP Server
三、视频演示
[localvideo]fa1116faca696981b17a0e147be51e79[/localvideo]
四、完整代码
import gc
import uos
import time
import socket
import network
from time import localtime
from machine import Pin, SPI
from micropython import const
import sdcard
"""初始化引脚"""
_LED_PIN = const(25) # LED 指示灯引脚
_SPI_SPEED = const(2_000_000) # SPI 速率
_MOSI_PIN = const(19) # SPI MOSI 引脚
_MISO_PIN = const(16) # SPI MISO 引脚
_SCK_PIN = const(18) # SPI SCK 引脚
_CS_PIN = const(17) # SPI CS 引脚
_RST_PIN = const(20) # SPI RESET 引脚
FTP_ROOT_PATH = const("/sd") # FTP 根目录
month_name = ["","Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec",
]
# SPI 定义
spi = SPI(0, _SPI_SPEED, mosi=Pin(_MOSI_PIN), miso=Pin(_MISO_PIN), sck=Pin(_SCK_PIN))
nic = None
spi1 = SPI(1, sck=Pin(10), mosi=Pin(11), miso=Pin(12))
cs = Pin(13)
sd = sdcard.SDCard(spi1,cs)
# 挂载文件到sd
uos.mount(sd,"/sd")
# 列出MicroSD/TF卡中的目录文件
print(uos.listdir('/sd'))
# 写文件测试
f = open('/sd/test.txt','w',encoding='utf-8')
f.write('MicroSD/TF存储卡访问测试!')
f.close()
# 读文件测试
f = open('/sd/test.txt','r')
print(f.read())
f.close()
""" W5500 初始化 """
def w5x00_init():
global nic
# 网口初始化
nic = network.WIZNET5K(spi, Pin(_CS_PIN), Pin(_RST_PIN)) # spi,cs,reset pin
nic.active(True)
# 配置网络
# If you use the Dynamic IP(DHCP), you must use the "nic.ifconfig('dhcp')".
nic.ifconfig("dhcp")
# If you use the Static IP, you must use the "nic.ifconfig("IP","subnet","Gateway","DNS")".
# nic.ifconfig(('192.168.0.106','255.255.255.0','192.168.0.1','114.114.114.114'))
while not nic.isconnected():
time.sleep(1)
print(nic.regs())
print("IP地址: %s" % nic.ifconfig()[0])
print("子网掩码: %s" % nic.ifconfig()[1])
print("网关: %s" % nic.ifconfig()[2])
print("DNS: %s" % nic.ifconfig()[3])
'/ 响应文件列表请求 /'
def send_list_data(path, dataclient, full):
try: # whether path is a directory name
for fname in uos.listdir(path):
dataclient.sendall(make_description(path, fname, full))
except: # path may be a file name or pattern
pattern = path.split("/")[-1]
path = path[: -(len(pattern) + 1)]
if path == "":
path = "/"
for fname in uos.listdir(path):
if fncmp(fname, pattern) == True:
dataclient.sendall(make_description(path, fname, full))
""" 列出目录详情 """
def make_description(path, fname, full):
if full:
stat = uos.stat(get_absolute_path(path, fname))
file_permissions = (
"drwxr-xr-x" if (stat[0] & 0o170000 == 0o040000) else "-rw-r--r--"
)
file_size = stat[6]
tm = localtime(stat[7])
if tm[0] != localtime()[0]:
description = "{} 1 owner group {:>10} {} {:2} {:>5} {}\r\n".format(
file_permissions, file_size, month_name[tm[1]], tm[2], tm[0], fname
)
else:
description = (
"{} 1 owner group {:>10} {} {:2} {:02}:{:02} {}\r\n".format(
file_permissions,
file_size,
month_name[tm[1]],
tm[2],
tm[3],
tm[4],
fname,
)
)
else:
description = fname + "\r\n"
return description
""" 发送文件数据 """
def send_file_data(path, dataclient):
try:
with open(path, "rb") as file:
chunk = file.read(512)
print("chunk 0: ", len(chunk))
while len(chunk) > 0:
print("chunk: ", len(chunk))
dataclient.sendall(chunk)
chunk = file.read(512)
except Exception as err:
print("error: ", err.args, err.value, err.errno)
""" 保存文件上传数据 """
def save_file_data(path, dataclient, mode):
with open(path, mode) as file:
chunk = dataclient.read(512)
while len(chunk) > 0:
file.write(chunk)
chunk = dataclient.read(512)
""" 获取文件绝对路径 """
def get_absolute_path(cwd, payload):
# Just a few special cases "..", "." and ""
# If payload start's with /, set cwd to /
# and consider the remainder a relative path
if payload.startswith("/"):
cwd = "/"
for token in payload.split("/"):
if token == "..":
if cwd != "/":
cwd = "/".join(cwd.split("/")[:-1])
if cwd == "":
cwd = "/"
elif token != "." and token != "":
if cwd == "/":
cwd += token
else:
cwd = cwd + "/" + token
return cwd
""" 文件名比较 """
def fncmp(fname, pattern):
pi = 0
si = 0
while pi < len(pattern) and si < len(fname):
if (fname[si] == pattern[pi]) or (pattern[pi] == "?"):
si += 1
pi += 1
else:
if pattern[pi] == "*": # recurse
if (pi + 1) == len(pattern):
return True
while si < len(fname):
if fncmp(fname[si:], pattern[pi + 1 :]) == True:
return True
else:
si += 1
return False
else:
return False
if pi == len(pattern.rstrip("*")) and si == len(fname):
return True
else:
return False
""" 启动FTP服务 """
def ftpserver():
DATA_PORT = 13333
ftpsocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
datasocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ftpsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
datasocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
ftpsocket.bind(socket.getaddrinfo("0.0.0.0", 21)[0][4])
datasocket.bind(socket.getaddrinfo("0.0.0.0", DATA_PORT)[0][4])
ftpsocket.listen(1)
datasocket.listen(1)
datasocket.settimeout(10)
print("FTP服务启动成功!监听端口:21")
msg_250_OK = "250 OK\r\n"
msg_550_fail = "550 Failed\r\n"
try:
dataclient = None
fromname = None
while True:
cl, remote_addr = ftpsocket.accept()
cl.settimeout(300)
cwd = FTP_ROOT_PATH
try:
print("新的FTP连接来自: %s:%s" % (remote_addr[0], remote_addr[1]))
cl.sendall("220 Welcome! This is the W5500_EVB_PICO!\r\n")
while True:
gc.collect()
data = cl.readline().decode("utf-8").rstrip("\r\n")
if len(data) <= 0:
print("Client disappeared")
break
command = data.split(" ")[0].upper()
payload = data[len(command) :].lstrip()
path = get_absolute_path(cwd, payload)
print("命令={}, 参数={}, 路径={}".format(command, payload, path))
if command == "USER":
cl.sendall("230 Logged in.\r\n")
elif command == "SYST":
cl.sendall("215 UNIX Type: L8\r\n")
elif command == "NOOP":
cl.sendall("200 OK\r\n")
elif command == "FEAT":
cl.sendall("211 no-features\r\n")
elif command == "PWD":
cl.sendall('257 "{}"\r\n'.format(cwd))
elif command == "CWD":
try:
files = uos.listdir(path)
cwd = path
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
elif command == "CDUP":
cwd = get_absolute_path(cwd, "..")
cl.sendall(msg_250_OK)
elif command == "TYPE":
# probably should switch between binary and not
cl.sendall("200 Transfer mode set\r\n")
elif command == "SIZE":
try:
size = uos.stat(path)[6]
cl.sendall("213 {}\r\n".format(size))
except:
cl.sendall(msg_550_fail)
elif command == "QUIT":
cl.sendall("221 Bye.\r\n")
break
elif command == "PASV":
addr = nic.ifconfig()[0]
cl.sendall(
"227 Entering Passive Mode ({},{},{}).\r\n".format(
addr.replace(".", ","), DATA_PORT >> 8, DATA_PORT % 256
)
)
dataclient, data_addr = datasocket.accept()
print("新的FTP数据连接来自: %s:%s" % (data_addr[0], data_addr[1]))
elif command == "LIST" or command == "NLST":
if not payload.startswith("-"):
place = path
else:
place = cwd
try:
send_list_data(
place, dataclient, command == "LIST" or payload == "-l"
)
cl.sendall("150 Here comes the directory listing.\r\n")
cl.sendall("226 Listed.\r\n")
except:
cl.sendall(msg_550_fail)
if dataclient is not None:
dataclient.close()
dataclient = None
elif command == "RETR":
try:
send_file_data(path, dataclient)
cl.sendall("150 Opening data connection.\r\n")
cl.sendall("226 Transfer complete.\r\n")
except:
cl.sendall(msg_550_fail)
if dataclient is not None:
dataclient.close()
dataclient = None
elif command == "STOR":
try:
cl.sendall("150 Ok to send data.\r\n")
save_file_data(path, dataclient, "wb")
cl.sendall("226 Transfer complete.\r\n")
except:
cl.sendall(msg_550_fail)
if dataclient is not None:
dataclient.close()
dataclient = None
elif command == "APPE":
try:
cl.sendall("150 Ok to send data.\r\n")
save_file_data(path, dataclient, "a")
cl.sendall("226 Transfer complete.\r\n")
except:
cl.sendall(msg_550_fail)
if dataclient is not None:
dataclient.close()
dataclient = None
elif command == "DELE":
try:
uos.remove(path)
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
elif command == "RMD":
try:
uos.rmdir(path)
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
elif command == "MKD":
try:
uos.mkdir(path)
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
elif command == "RNFR":
fromname = path
cl.sendall("350 Rename from\r\n")
elif command == "RNTO":
if fromname is not None:
try:
uos.rename(fromname, path)
cl.sendall(msg_250_OK)
except:
cl.sendall(msg_550_fail)
else:
cl.sendall(msg_550_fail)
fromname = None
else:
cl.sendall("502 Unsupported command.\r\n")
# print("Unsupported command {} with payload {}".format(command, payload))
except Exception as err:
print(err)
finally:
cl.close()
cl = None
finally:
datasocket.close()
ftpsocket.close()
if dataclient is not None:
dataclient.close()
if __name__ == "__main__":
print("run in main")
w5x00_init() # 初始化网络
ftpserver() # 运行 FTP Server