最近遇到一个问题:需要在 Unity3D 中识别手绘的指令。目前只有三个很简单的指令:顺时针旋转、逆时针旋转和箭头。
于是我想到用深度学习来给这三个指令分类,这需要用到 Python 的 PyTorch。
至于两个程序之间如何通信,有个朋友告诉我可以用 socket,那就用 socket 吧。

最后的效果就是这样的

unity 加载python模型 unity pytorch_unity 加载python模型


unity 加载python模型 unity pytorch_System_02

Unity C# 脚本代码如下(写得很粗糙,主要是为了实现功能 (~ ̄(OO) ̄)ブ):

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using System.Net;
using System.Net.Sockets;
using System;
using System.Threading;
using System.Text;
using System.IO;

/// <summary>
/// Draws strokes with the left mouse button, saves them as a PNG when "s" is pressed and
/// sends the PNG's name over TCP to a Python classifier, logging the class name it sends back.
/// Pressing "d" clears the drawing.
/// </summary>
/// <remarks>
/// Unity is the TCP server (port <see cref="listenPort"/>); the Python script connects as the
/// client. Accepting and receiving happen on background threads, so the connection state
/// (<c>iscon</c>, <c>m_socket</c>) is only touched under <c>m_gate</c>.
/// </remarks>
public class drawline : MonoBehaviour
{
    // Camera used for the screenshot; render only the line layer with it for a clean image.
    public Camera ShootCamera;
    // Prefab carrying a LineRenderer; one instance is created per stroke.
    public GameObject linePrefab;
    // LineRenderer of the stroke currently being drawn (null before the first stroke).
    private LineRenderer line;
    // Number of vertices in the current stroke.
    private int i;
    // Parent of every stroke, so they can all be cleared together.
    public GameObject obj_Parent;
    // TCP port to listen on; must match the port the Python client connects to.
    public int listenPort = 6666;

    // Screenshot folder (relative to the working directory); the Python side reads images from it.
    private const string GestureFolder = "Assets/gesture/";

    // Guards iscon and m_socket, which are written from the listener / socket callback threads.
    private readonly object m_gate = new object();
    private bool iscon = false;
    // Currently connected client, or null when nobody is connected.
    private Socket m_socket;
    // Listening socket, kept so OnDestroy can close it and release the port.
    private Socket m_listenSocket;
    private byte[] m_sendBuff;
    private byte[] m_recvBuff;
    private AsyncCallback m_recvCb;

    private void Start()
    {
        m_recvBuff = new byte[1024];
        m_sendBuff = new byte[1024];
        m_recvCb = new AsyncCallback(RecvCallBack);

        // Server socket: Unity listens and the Python classifier connects to it.
        m_listenSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        // IPAddress.Any accepts connections on every interface; IPAddress.Loopback would restrict to this machine.
        IPEndPoint point = new IPEndPoint(IPAddress.Any, listenPort);
        m_listenSocket.Bind(point);
        Debug.Log("服务器监听启动");
        m_listenSocket.Listen(5);
        Thread th = new Thread(Listen);
        th.IsBackground = true;
        th.Start(m_listenSocket);
    }

    // Listener thread: accepts clients until the listening socket is closed in OnDestroy.
    void Listen(object o)
    {
        Socket lfd = (Socket)o;
        while (true)
        {
            Socket rfd;
            try
            {
                rfd = lfd.Accept();
            }
            catch (Exception e) when (e is SocketException || e is ObjectDisposedException)
            {
                // The listening socket was closed: stop accepting.
                return;
            }
            Debug.Log(rfd.RemoteEndPoint.ToString() + "已连接了");

            // Only one classifier is served at a time; a newer connection replaces the old one.
            Socket previous;
            lock (m_gate)
            {
                previous = m_socket;
                m_socket = rfd;
                iscon = true;
            }
            if (previous != null)
            {
                previous.Close();
            }
            // Exactly one receive is outstanding per client; RecvCallBack re-arms it.
            BeginReceive(rfd);
        }
    }

    void Update()
    {
        // Left mouse button pressed this frame: start a new stroke.
        if (Input.GetMouseButtonDown(0))
        {
            GameObject go = Instantiate(linePrefab, linePrefab.transform.position, transform.rotation);
            line = go.GetComponent<LineRenderer>();
            // Parent directly under obj_Parent; looking it up by name could hit another
            // object with the same name, or nothing at all if it is inactive.
            line.transform.parent = obj_Parent.transform;
            // Material
            // line.material = new Material(Shader.Find("Particles/Additive"));
            // Colour
            line.startColor = Color.red;
            line.endColor = Color.red;
            // Width
            //line.startWidth = 0.1f;
            //line.endWidth = 0.1f;
            i = 0;
        }
        // Left mouse button held: append the cursor position to the current stroke.
        // `line` may be null (button held since before the first press) or destroyed by Btn_Close.
        if (Input.GetMouseButton(0) && line != null)
        {
            i++;
            line.positionCount = i;
            line.SetPosition(i - 1, Camera.main.ScreenToWorldPoint(
                    new Vector3(Input.mousePosition.x, Input.mousePosition.y, 10)));
        }
        // "s": save the drawing and send the image name to the classifier.
        if (Input.GetKeyDown("s"))
        {
            screenShoot();
        }
        // "d": clear the current drawing.
        if (Input.GetKeyDown("d"))
        {
            Btn_Close();
        }
    }

    // Renders ShootCamera to a PNG in GestureFolder and sends the file name (without extension)
    // to the connected client, if any.
    private void screenShoot()
    {
        Camera camera = ShootCamera;
        Rect rect = new Rect(0, 0, Screen.width, Screen.height);
        // Render the camera into a temporary RenderTexture and read it back.
        RenderTexture rt = new RenderTexture((int)rect.width, (int)rect.height, 0);
        camera.targetTexture = rt;
        camera.Render();
        RenderTexture.active = rt;
        Texture2D t2D = new Texture2D((int)rect.width, (int)rect.height, TextureFormat.RGB24, false);
        t2D.ReadPixels(rect, 0, 0); // reads from RenderTexture.active
        t2D.Apply();
        // Restore the camera so it keeps rendering to the screen.
        camera.targetTexture = null;
        RenderTexture.active = null;
        Destroy(rt);
        byte[] bytes = t2D.EncodeToPNG();
        // The texture is only needed for encoding; without this every screenshot leaks it.
        Destroy(t2D);

        // Zero-padded timestamp including the day, so names neither collide across days
        // nor become ambiguous (e.g. month 1 + hour 12 vs month 11 + hour 2).
        string nowtime = DateTime.Now.ToString("yyyyMMddHHmmss", System.Globalization.CultureInfo.InvariantCulture);
        Directory.CreateDirectory(GestureFolder);
        File.WriteAllBytes(GestureFolder + nowtime + ".png", bytes);
        Debug.Log("按下了截图键");

        Socket client;
        lock (m_gate)
        {
            client = iscon ? m_socket : null;
        }
        if (client == null)
        {
            Debug.LogWarning("尚未有客户端连接,图片名未发送: " + nowtime);
            return;
        }
        m_sendBuff = Encoding.UTF8.GetBytes(nowtime);
        try
        {
            client.Send(m_sendBuff);
        }
        catch (Exception e) when (e is SocketException || e is ObjectDisposedException)
        {
            OnClientLost(client, e.Message);
        }
    }

    // Starts one asynchronous receive on `client`; the socket itself is passed as the async state.
    private void BeginReceive(Socket client)
    {
        try
        {
            client.BeginReceive(m_recvBuff, 0, m_recvBuff.Length, SocketFlags.None, m_recvCb, client);
        }
        catch (Exception e) when (e is SocketException || e is ObjectDisposedException)
        {
            OnClientLost(client, e.Message);
        }
    }

    // Receive completion (thread-pool thread): logs the class name and re-arms the receive.
    void RecvCallBack(IAsyncResult ar)
    {
        // Use the socket this receive was started on, not m_socket, which may have been replaced.
        Socket client = (Socket)ar.AsyncState;
        int len;
        try
        {
            len = client.EndReceive(ar);
        }
        catch (Exception e) when (e is SocketException || e is ObjectDisposedException)
        {
            OnClientLost(client, e.Message);
            return;
        }
        if (len == 0)
        {
            // A zero-byte read means the peer closed the connection.
            OnClientLost(client, "对方关闭了连接");
            return;
        }
        string msgStr = Encoding.UTF8.GetString(m_recvBuff, 0, len);
        Debug.Log("图片的类型是:" + msgStr);
        BeginReceive(client);
    }

    // Marks `client` as disconnected (only if it is still the current one) and closes it.
    private void OnClientLost(Socket client, string reason)
    {
        lock (m_gate)
        {
            if (m_socket == client)
            {
                m_socket = null;
                iscon = false;
            }
        }
        client.Close();
        Debug.Log("客户端连接已断开: " + reason);
    }

    // Closes the client and listening sockets so the listener thread exits and the port is released
    // (in the editor, background threads otherwise survive leaving play mode).
    private void OnDestroy()
    {
        Socket client;
        lock (m_gate)
        {
            client = m_socket;
            m_socket = null;
            iscon = false;
        }
        if (client != null)
        {
            client.Close();
        }
        if (m_listenSocket != null)
        {
            m_listenSocket.Close();
            m_listenSocket = null;
        }
    }

    // Destroys every child of obj_Parent, clearing the drawing.
    public void Btn_Close()
    {
        // Destroy is deferred to the end of the frame, so childCount is stable during the loop.
        for (int i = 0; i < obj_Parent.transform.childCount; i++)
        {
            Destroy(obj_Parent.transform.GetChild(i).gameObject);
        }
    }
}

负责截图的相机作为主相机的子对象。可以给线条单独设置一个层(Layer),并把截图相机的 Culling Mask 设为只渲染这一层,这样截图相机只能看到线条,不受其他游戏对象干扰,截出来的图片会干净一些。

unity 加载python模型 unity pytorch_System_03


unity 加载python模型 unity pytorch_python_04


unity 加载python模型 unity pytorch_unity_05

下面是 Python 的代码(深度学习模型就不放出来了,所以代码里用到的 net 的定义以及 import 部分都省略了;而且我刚学 Python (°ー°〃)):

# Class labels indexed by the network's output position. The order appears to be
# sorted() order of the class names (as torchvision's ImageFolder produces) — TODO
# confirm it matches the label order used during training.
classes = ['Cclockwise', 'arrow', 'clockwise']
# Load the trained weights. map_location='cpu' lets a checkpoint saved from a GPU run
# load on a machine without CUDA; load_state_dict then copies the tensors into `net`'s
# own parameters on whatever device they live on.
checkpoint = torch.load('E:/pyCode/weightsave.t7', map_location='cpu')
# strict=False silently skips missing/unexpected keys, which would leave layers at
# their random initialisation; report any mismatch so it does not go unnoticed.
_load_report = net.load_state_dict(checkpoint['state'], strict=False)
if _load_report.missing_keys or _load_report.unexpected_keys:
    print('权重加载不完整:', _load_report)
net.eval()

# Hand-off between the classifier and the sender thread:
# classfypic stores the label in piclass, then sets issendtime.
issendtime = False
piclass = ''


# 分类单张图片
def classfypic(name, folder='E:/unitylearn/draw/Assets/gesture/'):
    """Classify the screenshot ``folder + name + '.png'`` and queue the result for sending.

    On success, sets the module globals ``piclass`` (the predicted label from
    ``classes``) and then ``issendtime = True`` so the sender thread sends it.

    Args:
        name: image file name without extension, as received from Unity.
        folder: directory Unity writes the screenshots to (with trailing slash).

    Returns:
        The predicted class name, or None if the image could not be read.
    """
    global piclass, issendtime
    path = folder + name + '.png'
    img = cv2.imread(path)
    if img is None:
        # cv2.imread does not raise on a missing/unreadable file; it returns None.
        print('无法读取图片:', path)
        return None
    # OpenCV loads BGR; reorder the channels to RGB and scale to [0, 1].
    img = img[:, :, (2, 1, 0)] / 255.
    # HWC -> NCHW with a batch dimension of one.
    img = torch.FloatTensor(img).unsqueeze(0).permute(0, 3, 1, 2)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = net(img)
    out = nn.Softmax(dim=1)(out)
    cls = int(torch.argmax(out, dim=1).item())
    piclass = classes[cls]
    print('类别:', classes[cls])
    issendtime = True
    return piclass


#  ****************************************************************************************************
#  Networking
instring = ''


def client_sent(sock, poll_interval=0.01):
    """Send each classification result back to Unity over ``sock``.

    Polls the ``issendtime`` flag raised by ``classfypic``. Sleeping ``poll_interval``
    seconds between polls avoids spinning a CPU core at 100% while idle.
    Runs until sending fails (e.g. Unity closed the connection).
    """
    import time
    global piclass, issendtime
    while True:
        if issendtime:
            # sendall: a plain send() on a stream socket may transmit only part of the data.
            sock.sendall(piclass.encode('UTF-8'))
            issendtime = False
        else:
            time.sleep(poll_interval)


def client_recv(sock):
    """Receive screenshot names from Unity and classify each one.

    Returns once the peer closes the connection.
    """
    while True:
        data = sock.recv(1024)  # receive data
        if not data:
            # recv() returns b'' when the peer closes the socket. bytes never equal '',
            # so the check must test emptiness rather than compare with a str.
            break
        name = data.decode('UTF-8').strip()
        print('图片名字为' + name)
        classfypic(name)


# Connect to the Unity server, then run one thread that sends results and one that
# receives image names.
clf = socket.socket()
clf.connect(('127.0.0.1', 6666))
print('创建连接')
# The sender loops forever, so it is a daemon: it must not keep the process alive
# once the receiver returns because Unity disconnected.
th_send = threading.Thread(target=client_sent, args=(clf,), daemon=True)  # sender thread
th_send.start()
th_recv = threading.Thread(target=client_recv, args=(clf,))  # receiver thread
th_recv.start()
try:
    th_recv.join()
finally:
    clf.close()

这样可以实现通信,只是效率有点低。不知道是不是我的电脑太垃圾的缘故:没有好的显卡,不能用 CUDA 来训练,一训练就要半天时间 ┗( T﹏T )┛

代码只是实现了基本功能,仅供参考。