<-- Home |--rust |--matlab

Rust for Matlab里面调用Rust函数

因为能

前面写WASM的时候,发现名为WASM实则是一个C语言的动态链接库。

既然如此,为什么不能用Rust来写一个动态链接库,然后Matlab来调用呢?

虽然,不知道为什么要这么干。实际上,我连用C/C++写一个动态链接库然后Matlab来调用都嫌没必要,最多是写一个可执行文件,然后处理输入到输出,然后matlab调用这个可执行文件。

就当做是没苦硬吃。

那就开始

首先,建立一个Rust的库。

1cargo new --lib rs2m

然后在Cargo.toml中添加

1[lib]
2crate-type = ["cdylib"]

变成:

 1[package]
 2name = "rs2m"
 3version = "0.1.0"
 4edition = "2021"
 5
 6[lib]
 7name = "rs2m"
 8crate-type = ["cdylib"]
 9
10[dependencies]

然后,写一个简单的函数,然后导出。

 1#[no_mangle]
 2pub extern "C" fn add(left: u64, right: u64) -> u64 {
 3    left + right
 4}
 5
 6#[no_mangle]
 7pub extern "C" fn square(x: f64) -> f64 {
 8    x * x
 9}
10
11#[no_mangle]
12pub extern "C" fn linspace(start: f64, end: f64, n: usize, out_ptr: *mut f64) -> i32 {
13    if out_ptr.is_null() {
14        return 0;
15    }
16    
17    let step = if n > 1 {
18        (end - start) / (n - 1) as f64
19    } else {
20        0.0
21    };
22    
23    unsafe {
24        for i in 0..n {
25            *out_ptr.add(i) = start + step * i as f64;
26        }
27    }
28    
29    n as i32
30}

注意到这里,我们每定义一个函数,都要在前面加上#[no_mangle],这个是告诉编译器,这个函数不要进行名字修饰,否则在Matlab中调用的时候会找不到。

然后,在函数的返回值前面加上pub extern "C" fn,这个是告诉编译器,这个函数是导出的,然后"C"表示这个函数是按照C语言的规则来导出的。

普通的值的传递乏善可陈,就是简单的对应关系,可以在Comparing C/C++ types to Rust中找到;然后C/C++和Matlab的对应可以不太在意,因为Matlab是一个工程师的懒人语言,通常不纠结数据类型,除非算不出来。

唯一好玩一点点的就是,浮点数组的传递,这里需要用*mut f64来表示,然后Matlab中需要用libpointer('doublePtr', result)来表示。这个*mut f64是Rust的原始指针类型。

 1#[no_mangle]
 2pub extern "C" fn linspace(start: f64, end: f64, n: usize, out_ptr: *mut f64) -> i32 {
 3    if out_ptr.is_null() {
 4        return 0;
 5    }
 6    
 7    let step = if n > 1 {
 8        (end - start) / (n - 1) as f64
 9    } else {
10        0.0
11    };
12    
13    unsafe {
14        for i in 0..n {
15            *out_ptr.add(i) = start + step * i as f64;
16        }
17    }
18    
19    n as i32
20}

这个代码中就有对原始指针的两个操作,is_null()add()

is_null()是判断指针是否为空,add()是获取指针指向的值。我们使用了unsafe来操作原始指针指向的值。

1cargo build --release

Rust这一边的事情就全部完成。最终我们我们也弄了点测试,假装测试一下。

 1
 2#[cfg(test)]
 3mod tests {
 4    use super::*;
 5
 6    #[test]
 7    fn test_add() {
 8        let result = add(2, 2);
 9        assert_eq!(result, 4);
10    }
11
12    #[test] 
13    fn test_linspace() {
14        let mut output = vec![0.0; 11];
15        let len = linspace(0.0, 1.0, 11, output.as_mut_ptr());
16        assert_eq!(len, 11);
17        for (idx, &ret) in output.iter().enumerate() {
18            let expected = idx as f64 * 0.1;
19            assert!((ret - expected).abs() <= f64::EPSILON);
20        }   
21    }
22
23    #[test]
24    fn test_square() {
25        assert_eq!(square(2.0), 4.0);
26        assert_eq!(square(-3.0), 9.0);
27        assert_eq!(square(0.0), 0.0);
28    }
29}

当然,这里测试中的原始指针就通过Vec<f64>::as_mut_ptr()来获取。

1cargo test

Matlab与C语言的接口

基本上C语言的接口,ABI,是计算机中非常非常通用的。Matlab也不例外,提供了一系列函数来调用。

比较重要的函数有:

  • loadlibrary:加载动态链接库
  • calllib:调用动态链接库中的函数
  • libfunctions:查看动态链接库中的函数
  • libisloaded:查看动态链接库是否加载
  • unloadlibrary:卸载动态链接库
  • libstruct:创建一个结构体
  • libgetptr:获取一个结构体的指针

大概,也就是这么多,可以用Matlab的帮助去查看如何使用。

我们上面的库,定义的函数非常简单,所以头文件也可以非常简单。

1    uint64_t add(uint64_t left, uint64_t right);
2
3    // Generate evenly spaced array
4    int32_t linspace(double start, double end, int32_t n, double *out_ptr);
5
6    // Square function
7    double square(double x);

整个文件:

主要是引用stdint.h,来访问uint64_tint32_t这些更加具有一致性的类型定义。

然后,我们就可以在Matlab中调用这个库了。

1[~, ~] = loadlibrary('rs2m.dll', 'rs2m.h'); 
2calllib('rs2m', 'add', uint64(2), uint64(3));
3arr = libpointer('doublePtr', zeros(1, 11));
4calllib('rs2m', 'linspace', 0, 1, 11, arr);
5calllib('rs2m', 'square', 2);
6% 卸载库
7unloadlibrary('rs2m');

我们通常会包装一层,来方便使用。

 1classdef rs2m
 2    properties (Constant)
 3        LIB_PATH = 'rs2m.dll';
 4        HEADER_PATH = 'rs2m.h';
 5    end
 6    methods(Static)
 7        function ensureLibraryLoaded()
 8            % Get the current directory
 9            current_dir = fileparts(mfilename('fullpath'));
10            % Load the library
11            lib_path = char(fullfile(current_dir, rs2m.LIB_PATH));
12            header_path = char(fullfile(current_dir, rs2m.HEADER_PATH));
13            [~, ~] = loadlibrary(lib_path, header_path, 'mfilename', 'rs2mproto');
14            % Display available functions
15            disp('Available library functions:');
16            libfunctions('rs2m', '-full')
17        end
18        
19        function unloadLibrary()
20            % Unload the library
21            if libisloaded('rs2m')
22                unloadlibrary('rs2m');
23            end
24            persistent isLoaded;
25            if ~isempty(isLoaded)
26                isLoaded = false;
27            end
28        end
29        
30        function result = add(a, b)
31            % Call Rust add function
32            % Parameters:
33            %   a, b: uint64 numbers
34            % Returns:
35            %   result: uint64 result
36            if ~isa(a, 'uint64') || ~isa(b, 'uint64')
37                error('Inputs must be uint64');
38            end
39            
40            result = calllib('rs2m', 'add', a, b);
41        end
42        
43        function result = linspace(start, end_val, n)
44            % Call Rust linspace function
45            % Parameters:
46            %   start: start value
47            %   end_val: end value
48            %   n: number of points
49            % Returns:
50            %   result: array of evenly spaced values
51            
52            
53            % Calculate array length
54            len = n;
55            
56            % Create output array
57            result = zeros(1, len);
58            ptr = libpointer('doublePtr', result);
59            
60            % Call Rust function
61            try
62                actual_len = calllib('rs2m', 'linspace', start, end_val, n, ptr);
63                if actual_len ~= len
64                    warning('Expected length %d but got %d', len, actual_len);
65                end
66                result = ptr.Value;
67            catch ME
68                error('Error calling linspace: %s\nStack: %s', ME.message, getReport(ME, 'extended'));
69            end
70        end
71        
72        function result = square(x)
73            % Call Rust square function
74            % Parameters:
75            %   x: input value
76            % Returns:
77            %   result: x squared
78            
79            result = calllib('rs2m', 'square', x);
80        end
81    end
82end

代码块11中的13行,我们在调用loadlibrary函数时,还提供了一组参数让matlab输出一个m文件,这个m文件可以获得所有方法的接口信息。

测试代码:

 1% Test script for rs2m library
 2rs2m.unloadLibrary();
 3rs2m.ensureLibraryLoaded();
 4% Test add function
 5disp('Testing add function...');
 6a = uint64(2);
 7b = uint64(3);
 8result = rs2m.add(a, b);
 9fprintf('2 + 3 = %d\n', result);
10
11% Test linspace function
12disp('Testing linspace function...');
13result = rs2m.linspace(0, 1, 11);
14fprintf('linspace(0, 1, 11) = %s\n', mat2str(result));
15
16% Test square function
17disp('Testing square function...');
18result = rs2m.square(5);
19fprintf('square(5) = %d\n', result);
20
21% Test square function
22disp('Testing square function...');
23result = rs2m.square(5);
24fprintf('5^2 = %f\n', result);
25
26% Unload library when done
27rs2m.unloadLibrary();

正常运行。C语言的使用方式简单可靠,唯一就是那个libpointer,有点烦人,我一直没找到如何设置数组尺寸,只能通过构造一个初始值的方式,如果这个数组很大,那就太坑了。

接着往下看,Matlab还提供了一套跟C++的接口,叫做clib。稍微看了一下就感觉,这个更有搞头。

Matlab与C++的接口

这个方式来调用编译好的库,貌似需要三个文件:

  • rs2m.h:头文件
  • rs2m.dll:库文件
  • rs2m.lib:库文件的桩

如果是cpp源文件,那就直接使用头文件和源文件即可。当然,可能需要安装个什么c++的编译器。在windows上,安装一个Visual C++社区版,然后就可以使用。我们这里的库文件,cargo build --release,就会生成。头文件跟刚才相同。

1clibgen.generateLibraryDefinition("rs2m.h", Libraries="rs2m.dll", OutputFolder="rs2mlib-cpp");

这个命令会生成一个rs2mlib-cpp的文件夹,里面包含一个definers2m.m文件和另外一个xml文件,这个文件就是我们要的包装文件。

这里唯一一个坑,就是definers2m.m文件中,有些函数的定义需要补充。对于我们这几个函数,就是linspace函数,需要我们自己补充参数的维度。

 1%% C++ function |linspace| with MATLAB name |clib.rs2m.linspace|
 2% C++ Signature: int32_t linspace(double start,double end,int32_t n,double * out_ptr)
 3
 4%linspaceDefinition = addFunction(libDef, ...
 5%    "int32_t linspace(double start,double end,int32_t n,double * out_ptr)", ...
 6%    "MATLABName", "clib.rs2m.linspace", ...
 7%    "Description", "clib.rs2m.linspace Representation of C++ function linspace."); % Modify help description values as needed.
 8%defineArgument(linspaceDefinition, "start", "double");
 9%defineArgument(linspaceDefinition, "end", "double");
10%defineArgument(linspaceDefinition, "n", "int32");
11%defineArgument(linspaceDefinition, "out_ptr", "clib.array.rs2m.Double", "input", <SHAPE>); % <MLTYPE> can be "clib.array.rs2m.Double", or "double"
12%defineOutput(linspaceDefinition, "RetVal", "int32");
13%validate(linspaceDefinition);

我么需要把这个注释去掉,然后把<SHAPE>补充完整。

1defineArgument(linspaceDefinition, "out_ptr", "clib.array.rs2m.Double", "input", "n"); % <MLTYPE> can be "clib.array.rs2m.Double", or "double"

注意这里,用的"n",也就是原来方法的第三个参数,实际上,定义了这个尺寸之后,clib.rs2m.linspace的第三个参数n就被省略了,直接包含在第四个参数中。

然后就可以使用build命令来生成dll文件了。

1build(definers2m);      

这个命令会生成一个dll文件:rs2mInterface.dll,有了这两个dll文件,就可以在Matlab中调用这个库了。当然调用之前,建议先做一个设置:

1clibConfiguration('rs2m', 'ExecutionMode', 'outofprocess');

如果运行模式为inprocess,则需要重新启动Matlab才能卸载该库,当然我们使用了outofprocess,所以不需要重新启动Matlab就可以卸载该库。

1clibConfiguration('rs2m').unload();

然后就可以在Matlab中调用这个库了。

 1
 2doc clibgen.generateLibraryDefinition
 3
 4clibgen.generateLibraryDefinition("rs2m.h", Libraries="rs2m.dll", OutputFolder="rs2mlib-cpp");
 5
 6%% find functions
 7help clib.rs2m
 8
 9%% test add
10help clib.rs2m.add
11a = uint64(2);
12b = uint64(3);
13result = clib.rs2m.add(a, b);
14fprintf('2 + 3 = %d\n', result);
15
16%% test linspace
17help clib.rs2m.linspace
18start = 0;
19end_val = 1;
20n = 11;
21arr = clibArray("clib.rs2m.Double", n);
22result = clib.rs2m.linspace(start, end_val, arr);
23fprintf('linspace(0, 1, 11) = %s\n', mat2str(result));
24
25%% test square
26help clib.rs2m.square
27x = 2;
28result = clib.rs2m.square(x);
29fprintf('square(2) = %d\n', result);

正常运行。

总结

忙活挺长时间才搞定这个,Matlab真是无所不能的。


文章标签

|-->rust |-->matlab |-->c |-->abi |-->pointer


GitHub