pytorch c++ 编译

最近在搞pytorch c++ inference,需要编译pytorch lib库。

折腾了几次,终于搞定,这里记录下来。

clone源码

git clone https://github.com/pytorch/pytorch.git

git checkout tags/v1.3.1

git checkout -b v1.3.1

git submodule sync

git submodule update --init --recursive

准备环境,miniconda、cmake(3.5+)、cuda 9.0+

miniconda安装

wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh

sh Miniconda3-latest-Linux-x86_64.sh

cmake安装

curl -LO https://github.com/Kitware/CMake/releases/download/v3.15.5/cmake-3.15.5.tar.gz

./configure --prefix=install_dir

cuda使用yum安装

安装devtoolset-7,pytorch编译对gcc有要求,不然会出现编译错误,参考https://github.com/pytorch/builder/blob/master/update_compiler.sh

yum install -y -q devtoolset-7-gcc devtoolset-7-gcc-c++ devtoolset-7-gcc-gfortran devtoolset-7-binutils

source /opt/rh/devtoolset-7/enable

源码编译,可以根目录直接make,这里选择手动安装到指定目录

进入torch根目录,

mkdir build

cmake .. -DPYTHON_EXECUTABLE:FILEPATH=/search/privateData/tools/miniconda3/bin/python -DPYTHON_INCLUDE_DIR=/search/privateData/tools/miniconda3/include/python3.7m -DCMAKE_INSTALL_PREFIX=/search/privateData/tools/libtorch-1.3.0

glibc readv和writev函数改进

最近在改进公司内部网络发送库,发现了linux下高级io操作函数readv和writev,在glibc里面实现的。

使用这两函数需要include<sys/uio.h>

ssize_t readv(int fd,const struct iovec *iov, int count); 

从文件描述符fd所对应的文件中读取数据,按顺序依次填入count个buffer中,每个buffer用一个iovec描述(count是iovec的个数,不是字节数)

ssize_t writev(int fd,const struct iovec *iov, int count);

把count个指定顺序的数据buffer(使用iovec描述)写入到文件描述符fd所对应的文件中

struct iovec结构在bits/uio.h中定义的,是一种向量形式的结构体。

/* Structure for scatter/gather I/O.  Each element describes one
   contiguous memory segment; readv and writev take an array of them.  */
struct iovec
  {
    void *iov_base; /* Pointer to the start of the segment.  */
    size_t iov_len; /* Length of the segment in bytes.  */
  };

能将本来需要多次发送的数据,聚合在一起,一次发送,提高IO效率。

但使用时发现了一些问题,readv一次不能完全接收到期望长度数据。查看glibc源码,发现readv、writev底层分别是基于read、write实现的,而read一次本来就可能获得不了期望长度数据。

It is not an error if this number is smaller than the number of bytes requested; this may happen for example because fewer bytes are actually available right now (maybe because we were close to end-of-file, or because we are reading from a pipe, or from a terminal), or because read() was interrupted by a signal.

查看glibc源码实现,发现里面并没有处理这个问题,所以才出现与期望不一致的问题。

以下是glibc readv实现,glibc/sysdeps/posix/readv.c,read只调用了一次。

#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <limits.h>
#include <stdbool.h>
#include <sys/param.h>
#include <sys/uio.h>
#include <errno.h>
/* Cleanup handler for a `char *' variable declared with
   __attribute__ ((__cleanup__ (ifree))).  PTRP is the address of that
   variable; the buffer it points to is released with free.  */
static void
ifree (char **ptrp)
{
  free (*ptrp);
}
/* Read data from file descriptor FD, and put the result in the
   buffers described by VECTOR, which is a vector of COUNT 'struct iovec's.
   The buffers are filled in the order specified.
   Operates just like 'read' (see <unistd.h>) except that data are
   put in VECTOR instead of a contiguous buffer.

   Returns the number of bytes read, which may be smaller than the
   combined size of the buffers because the underlying read is issued
   only once.  Returns -1 on failure; errno is set to EINVAL when the
   summed buffer lengths would exceed SSIZE_MAX.  */
ssize_t
__readv (int fd, const struct iovec *vector, int count)
{
  /* Find the total number of bytes to be read.  */
  size_t bytes = 0;
  for (int i = 0; i < count; ++i)
    {
      /* Check for ssize_t overflow.  */
      if (SSIZE_MAX - bytes < vector[i].iov_len)
        {
          __set_errno (EINVAL);
          return -1;
        }
      bytes += vector[i].iov_len;
    }
  /* Allocate a temporary buffer to hold the data.  We should normally
     use alloca since it's faster and does not require synchronization
     with other threads.  But we cannot if the amount of memory
     required is too large.  */
  char *buffer;
  /* Stays NULL on the alloca path; on the malloc path it keeps the
     original allocation address so that the cleanup handler frees the
     right pointer even though BUFFER is advanced below.  */
  char *malloced_buffer __attribute__ ((__cleanup__ (ifree))) = NULL;
  if (__libc_use_alloca (bytes))
    buffer = (char *) __alloca (bytes);
  else
    {
      malloced_buffer = buffer = (char *) malloc (bytes);
      if (buffer == NULL)
        return -1;
    }
  /* Read the data.  This is a single call: a short read is passed
     straight through to the caller, there is no retry.  */
  ssize_t bytes_read = __read (fd, buffer, bytes);
  if (bytes_read < 0)
    return -1;
  /* Copy the data from BUFFER into the memory specified by VECTOR.  */
  bytes = bytes_read;
  for (int i = 0; i < count; ++i)
    {
      /* Never copy more than what is left of the data actually read.  */
      size_t copy = MIN (vector[i].iov_len, bytes);
      (void) memcpy ((void *) vector[i].iov_base, (void *) buffer, copy);
      buffer += copy;
      bytes -= copy;
      if (bytes == 0)
        break;
    }
  return bytes_read;
}

所以需要把read调用改进一下,保证数据能读取完整。以下是改进

#define __set_errno(val) (errno = (val))

/* Cleanup handler for the malloc'ed staging buffer, invoked through
   __attribute__ ((__cleanup__ (ifree))).  PTRP is the address of the
   annotated variable.  errno is saved and restored so that the error
   code of a failed read is still intact when my_readv returns.  */
static void
ifree (char **ptrp)
{
  int saved_errno = errno;
  free (*ptrp);
  errno = saved_errno;
}

/* Read data from file descriptor FD, and put the result in the
   buffers described by VECTOR, which is a vector of COUNT 'struct iovec's.
   The buffers are filled in the order specified.

   Unlike a single 'read', this keeps reading until the buffers are
   completely filled, end-of-file is reached, or a hard error occurs.

   Returns the number of bytes stored in VECTOR.  A value smaller than
   the combined buffer size means end-of-file or an error was hit after
   some data had already been received.  Returns -1 with errno set when
   the summed lengths exceed SSIZE_MAX (EINVAL), when the staging buffer
   cannot be allocated, or when reading fails before any byte arrived.

   EINTR, EAGAIN and EWOULDBLOCK are retried immediately, so on a
   non-blocking descriptor this function busy-waits for data.  */
ssize_t
my_readv (int fd, const struct iovec *vector, int count)
{
  /* Largest request that is staged on the stack instead of the heap.  */
  enum { STACK_BUFFER_SIZE = 128 };

  /* Find the total number of bytes to be read.  */
  size_t bytes = 0;
  for (int i = 0; i < count; ++i)
    {
      /* Check for ssize_t overflow.  */
      if (SSIZE_MAX - bytes < vector[i].iov_len)
        {
          __set_errno (EINVAL);
          return -1;
        }
      bytes += vector[i].iov_len;
    }

  /* Stage the data in a temporary buffer: a fixed stack array for small
     requests (no allocation, no alloca), the heap for everything else.
     MALLOCED_BUFFER keeps the allocation address for the cleanup
     handler and stays NULL on the stack path.  */
  char stack_buffer[STACK_BUFFER_SIZE];
  char *buffer = stack_buffer;
  char *malloced_buffer __attribute__ ((__cleanup__ (ifree))) = NULL;
  if (bytes > sizeof stack_buffer)
    {
      malloced_buffer = buffer = malloc (bytes);
      if (buffer == NULL)
        return -1;
    }

  /* Read until the staging buffer is full.  The counter is a size_t so
     that it can be compared with BYTES without a signed/unsigned
     mismatch and cannot overflow for large requests.  */
  size_t bytes_read = 0;
  int read_failed = 0;
  while (bytes_read < bytes)
    {
      ssize_t ret = read (fd, buffer + bytes_read, bytes - bytes_read);
      if (ret > 0)
        {
          bytes_read += (size_t) ret;
          continue;
        }

      /* End of file: hand back whatever has been collected.  */
      if (ret == 0)
        break;

      /* Transient conditions: try again.  */
      if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
        continue;

      read_failed = 1;
      break;
    }

  /* Report the error only when nothing was received; otherwise deliver
     the partial data so that it is not lost.  */
  if (read_failed && bytes_read == 0)
    return -1;

  /* Copy the data from BUFFER into the memory specified by VECTOR.  */
  const char *src = buffer;
  size_t remaining = bytes_read;
  for (int i = 0; i < count && remaining > 0; ++i)
    {
      size_t copy = vector[i].iov_len < remaining
                    ? vector[i].iov_len : remaining;

      memcpy (vector[i].iov_base, src, copy);

      src += copy;
      remaining -= copy;
    }
  return (ssize_t) bytes_read;
}

完整code见https://github.com/zhangjun/my_notes/blob/master/linux/io

另外facebook  folly也有实现,见 https://github.com/facebook/folly/blob/master/folly/portability/SysUio.cpp

go中len和Count的区别

最近使用go lang, 用着特别爽。但最近开发一个模块,一直不符合预期。后面通过看go源码,发现趟着大坑了。

len或者Count都能获取字符串、字节数组长度。如

data := "hello"

data_len := len(data)

data_len := strings.Count(data, "") - 1

content := []byte("hello")

content_len := len(content)

content_len := bytes.Count(content, nil) - 1

但是len和Count还是有区别。以bytes为例,bytes.Count(data, nil)的返回值为utf8.RuneCount(data) + 1,所以bytes.Count(data, nil) - 1得到的是data中UTF-8字符(rune)的个数;而len返回实际字节长度。只有内容全部是ASCII时两者才相等。字符串类似。

c++ virtual table

// Polymorphic class whose object layout is inspected by main() below.
// The three virtual functions are declared in the order f, g, h; main()
// reads the vtable slots by index and expects to find them in that
// order, followed by the single int data member.
class Base {

public:
	Base(int a): a_(a) {}

	virtual void f() { cout << "Base::f" << endl; }

	virtual void g() { cout << "Base::g" << endl; }

	virtual void h() { cout << "Base::h" << endl; }
private:
	// Read by main() through a raw pointer offset from the object start.
	int a_;
};

// Pokes at the memory of a Base object to locate its vtable pointer and
// call the virtual functions through raw addresses.
// NOTE(review): none of this is portable C++. The offsets used below
// appear to assume 4-byte int, 8-byte pointers, a vtable pointer stored at
// the start of the object, and code addresses that fit in 32 bits --
// confirm for the target compiler/ABI before reusing.
int main() {
	// Signature shared by Base::f/g/h when called without `this`.
	typedef void(*Fun)(void);

	Base b(5);

	Fun pFun = NULL;
	
	// Two ints (8 bytes) past the object start: where a_ is expected to
	// live, right after the vtable pointer.
	cout << "a_:" << *((int*)&b + 2) << endl;

	// Prints the address of the object itself, i.e. where the vtable
	// pointer is stored.
	cout << "虚函数表地址:" << (int*)(&b) << endl;

	// Reads only the first int of the object and treats it as the vtable
	// address.
	int* vtable = (int*)*(int*)(&b);
	cout << "虚函数表 — 第一个函数地址:" << vtable << endl;

	// Invoke the virtual functions one after another. `vtable` is an
	// int*, so each 8-byte slot is two ints wide: +0, +2 and +4.

	//pFun = (Fun)*((int*)*(int*)(&b));
	pFun = (Fun)*(vtable);
	pFun();
	pFun = (Fun)*((int*)*(int*)(&b) + 2);
	pFun();
	pFun = (Fun)*(vtable + 4);
	pFun();
        // A function pointer streamed to cout is converted to bool, so this
        // prints 0 or 1 rather than an address.
        cout << "vtable end: " << (Fun)*(vtable + 6) << std::endl;

       // Same walk through an int** view: pVtab[0] is the vtable seen as an
       // int array. Even indices are called as functions; odd indices are
       // only printed (as bool, see above).
       int** pVtab = (int**)&b;
       pFun = (Fun)pVtab[0][0];
       pFun();  // base::f
       pFun = (Fun)pVtab[0][1];
       cout << pFun << endl;
       pFun = (Fun)pVtab[0][2];
       pFun();  // base::g
       pFun = (Fun)pVtab[0][3];
       cout << pFun << endl;
       pFun = (Fun)pVtab[0][4];
       pFun();  // base::h
       pFun = (Fun)pVtab[0][5];
       cout << pFun << endl;
       // Indices 6 and 7 lie past the three function slots.
       pFun = (Fun)pVtab[0][6];
       cout << pFun << endl;
       pFun = (Fun)pVtab[0][7];
       cout << pFun << endl;


	return 0;
}
输出结果:

a_:5
虚函数表地址:0x7ffe3f0a5970
虚函数表 — 第一个函数地址:0x400e70
Base::f
Base::g
Base::h
vtable end: 1
Base::f
0
Base::g
0
Base::h
0
1
1

类static_cast、dynamic_cast与RTTI

1、static_cast

// Polymorphic base class for the static_cast example.
class Base {
    public:
        // Virtual destructor: main() deletes a Derive object through a
        // Base*, which is undefined behaviour unless the base destructor
        // is virtual.
        virtual ~Base() {}

        // Overridden by Derive; prints which implementation was called.
        virtual void f() {
            cout << "Base::f() " << endl;
        }

};

// Derived class for the static_cast example: replaces Base::f() and adds
// a second virtual function, f2(), that Base does not have.
class Derive : public Base {
public:
    virtual void f() { std::cout << "Derive::f() " << std::endl; }

    virtual void f2() { std::cout << "Derive::f2() " << std::endl; }
};

// Demonstrates that static_cast performs a downcast without any runtime
// check of the object's real type.
int main(){
    // static_cast
    // Valid downcast: pb1 really points to a Derive, so f() dispatches to
    // Derive::f().
    Base *pb1 = new Derive();
    Derive *pd1 = static_cast<Derive*>(pb1);
    pd1 -> f();

    // Invalid downcast: pb2 points to a plain Base. The cast still
    // compiles; using pd2 is undefined behaviour. Calling f() goes through
    // the object's own vtable, so it typically ends up in Base::f().
    Base *pb2 = new Base();
    Derive *pd2 = static_cast<Derive*>(pb2);
    pd2 -> f();
    //pd2 -> f2();  // crashes (core dump): a Base object has no f2()

    delete pb1;
    delete pb2;
    return 0;
}

static_cast可以在基类和派生类之间转换(偏移指针),编译时确定,不保证类型转换安全。

2、dynamic_cast

// Polymorphic base class for the dynamic_cast example.
class Base {
    public:
        // Virtual destructor: main() deletes a Derive object through a
        // Base*, which is undefined behaviour unless the base destructor
        // is virtual.
        virtual ~Base() {}

        // Overridden by Derive; prints which implementation was called.
        virtual void f() {
            cout << "Base::f() " << endl;
        }

};

// Derived class for the dynamic_cast example: replaces Base::f() and adds
// a second virtual function, f2(), that Base does not have.
class Derive : public Base {
public:
    virtual void f() { std::cout << "Derive::f() " << std::endl; }

    virtual void f2() { std::cout << "Derive::f2() " << std::endl; }
};

// Demonstrates that dynamic_cast checks the object's real type at run
// time and yields a null pointer when the downcast is not valid.
int main(){
// dynamic_cast
    Base *pb3 = new Derive();
    Derive *pd3 = dynamic_cast<Derive*>(pb3);   // down cast, succeeds: the object is a Derive
    pd3 -> f();

    Base *pb4 = new Base();
    Derive *pd4 = dynamic_cast<Derive*>(pb4);   // down cast, fails: the object is only a Base
    if(pd4){                    // pd4 is NULL here, so this branch is skipped
        pd4 -> f();
        pd4 -> f2(); 
    }

    delete pb3;
    delete pb4;

    return 0;
}

专门用于有继承关系类之间转换,尤其是向下转换,运行时确定,是类型安全的。

3、RTTI

go lang 获取本机ip

方法一

// GetIp logs the full address list of this host, then logs every
// non-loopback IPv4 address found in it, one per line.
// If the address list cannot be obtained, the error is logged and the
// function returns without printing anything else.
func GetIp() {
    addrs, err := net.InterfaceAddrs()
    if err != nil {
        log.Println(err)
        return
    }
    log.Println(addrs)

    for _, addr := range addrs {
        ipnet, ok := addr.(*net.IPNet)
        if !ok || ipnet.IP.IsLoopback() {
            continue
        }
        // To4 returns nil for addresses that are not IPv4.
        if ipnet.IP.To4() != nil {
            log.Println(ipnet.IP.String())
        }
    }
}

但是这个方法有个问题,如果本机存在虚拟机、docker等虚拟网卡,这类的ipv4地址同样能获取到

改进下,使用方法二

// GetIpV3 returns the first non-loopback IPv4 address found on a network
// interface that is up, is not a loopback interface and has "eth" in its
// name.
//
// It returns the error from net.Interfaces when the interface list cannot
// be read, and a "valid ip not found" error when no interface qualifies.
// Interfaces whose addresses cannot be listed are skipped.
func GetIpV3() (string, error) {
    ifaces, err := net.Interfaces()
    if err != nil {
        return "", err
    }

    for _, iface := range ifaces {
        if iface.Flags&net.FlagUp == 0 {
            continue
        }

        // Flags is a bit mask and FlagLoopback is not bit 0, so the masked
        // value must be compared against zero; comparing it with 1 never
        // matches.
        if iface.Flags&net.FlagLoopback != 0 {
            continue
        }

        // Keep physical-looking NICs only; this drops docker/VM bridges.
        if !strings.Contains(iface.Name, "eth") {
            continue
        }

        addrList, err := iface.Addrs()
        if err != nil {
            continue
        }

        for _, address := range addrList {
            ipnet, ok := address.(*net.IPNet)
            if !ok || ipnet.IP.IsLoopback() {
                continue
            }
            if ipnet.IP.To4() != nil {
                return ipnet.IP.String(), nil
            }
        }
    }

    return "", errors.New("valid ip not found")
}

对网卡进行过滤。

CentOS7 防火墙配置

最近在CentOS 7上使用rsync,莫名其妙发现无法连通,最后发现是防火墙打开了。。。
而且CentOS7上防火墙进行了升级,不是使用的iptables
开启防火墙 (关闭stop)

systemctl start firewalld.service

添加端口

firewall-cmd --permanent --zone=public --add-port=873/tcp

(添加后,记得firewall-cmd --reload)

一段简单c++代码

最近遇到一个有意思的问题,代码很简单,涉及到的东西会有一些,包括编译、虚函数等。

// Base class of the puzzle: print() is virtual and its parameter has the
// default value 1.
class A {
public:
    A() {}
    A(int) {}

    virtual void print(int j = 1) {
        std::cout << "base: " << j << std::endl;
    }
};
// Derived class of the puzzle: overrides print() with a different default
// value (4) and additionally holds an A object as a member.
class B : public A {
public:
    B() {}

    virtual void print(int j = 4) {
        std::cout << "derived: " << j << std::endl;
    }

    A a;
};
int main(){
    B b;
    A *a = &b;
    a -> print();
    return 0;
}

上面输出是多少

md5算法原理

MD5还是好久之前学过,早已经忘了。。。最近项目需要,又拾起来看了看,记录下来。
MD5,中文名为消息摘要算法,常用于数据完整校验。通过特定hash散列算法将文本信息转换为简短的消息摘要。
md5以512位分组处理信息,每一分组又被划分为16个32位子分组,变换处理后,输出由4个32位分组组成,级联生成一个128位散列值。
大致过程如下:

  • 填充

假设输入信息长度为len(单位:位),在消息末尾先补一个1,再补若干个0,直到长度满足 len mod 512 = 448

  • 填充消息长度

将填充前的消息长度用64位表示,附加到上一步结果之后;如果消息长度大于2^64,则取低64位

  • 变换处理

常数 A=0X67452301L,B=0XEFCDAB89L,C=0X98BADCFEL,D=0X10325476L
4个变换函数

F(X,Y,Z)=(X&Y)|((~X)&Z)
G(X,Y,Z)=(X&Z)|(Y&(~Z))
H(X,Y,Z)=X^Y^Z
I(X,Y,Z)=Y^(X|(~Z))

4轮变换的操作

FF(a,b,c,d,Mj,s,ti)表示a=b+((a+F(b,c,d)+Mj+ti)<<<s)
GG(a,b,c,d,Mj,s,ti)表示a=b+((a+G(b,c,d)+Mj+ti)<<<s)
HH(a,b,c,d,Mj,s,ti)表示a=b+((a+H(b,c,d)+Mj+ti)<<<s)
II(a,b,c,d,Mj,s,ti)表示a=b+((a+I(b,c,d)+Mj+ti)<<<s)
Mj表示消息的第j个子分组

每个分组的4轮变换完成后,A、B、C、D分别加上a、b、c、d,然后处理下一个分组