Skip to content

Commit c7ee230

Browse files
committed
fix ToMultiDimArray
1 parent 94601f5 commit c7ee230

File tree

5 files changed

+100
-23
lines changed

5 files changed

+100
-23
lines changed

src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,62 @@ static T Scalar<T>(long input)
3030
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
3131
_ => throw new NotImplementedException("")
3232
};
33+
34+
public static unsafe Array ToMultiDimArray<T>(NDArray nd) where T : unmanaged
35+
{
36+
var ret = Array.CreateInstance(typeof(T), nd.shape.as_int_list());
37+
38+
var addr = ret switch
39+
{
40+
T[] array => Addr(array),
41+
T[,] array => Addr(array),
42+
T[,,] array => Addr(array),
43+
T[,,,] array => Addr(array),
44+
T[,,,,] array => Addr(array),
45+
T[,,,,,] array => Addr(array),
46+
_ => throw new NotImplementedException("")
47+
};
48+
49+
System.Buffer.MemoryCopy(nd.data.ToPointer(), addr, nd.bytesize, nd.bytesize);
50+
return ret;
51+
}
52+
53+
#region multiple array
54+
static unsafe T* Addr<T>(T[] array) where T : unmanaged
55+
{
56+
fixed (T* a = &array[0])
57+
return a;
58+
}
59+
60+
static unsafe T* Addr<T>(T[,] array) where T : unmanaged
61+
{
62+
fixed (T* a = &array[0, 0])
63+
return a;
64+
}
65+
66+
static unsafe T* Addr<T>(T[,,] array) where T : unmanaged
67+
{
68+
fixed (T* a = &array[0, 0, 0])
69+
return a;
70+
}
71+
72+
static unsafe T* Addr<T>(T[,,,] array) where T : unmanaged
73+
{
74+
fixed (T* a = &array[0, 0, 0, 0])
75+
return a;
76+
}
77+
78+
static unsafe T* Addr<T>(T[,,,,] array) where T : unmanaged
79+
{
80+
fixed (T* a = &array[0, 0, 0, 0, 0])
81+
return a;
82+
}
83+
84+
static unsafe T* Addr<T>(T[,,,,,] array) where T : unmanaged
85+
{
86+
fixed (T* a = &array[0, 0, 0, 0, 0, 0])
87+
return a;
88+
}
89+
#endregion
3390
}
3491
}

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,28 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11-
public NDArray(bool value) : base(value) { NewEagerTensorHandle(); }
12-
public NDArray(byte value) : base(value) { NewEagerTensorHandle(); }
13-
public NDArray(short value) : base(value) { NewEagerTensorHandle(); }
14-
public NDArray(int value) : base(value) { NewEagerTensorHandle(); }
15-
public NDArray(long value) : base(value) { NewEagerTensorHandle(); }
16-
public NDArray(float value) : base(value) { NewEagerTensorHandle(); }
17-
public NDArray(double value) : base(value) { NewEagerTensorHandle(); }
11+
public NDArray(bool value) : base(value) => NewEagerTensorHandle();
12+
public NDArray(byte value) : base(value) => NewEagerTensorHandle();
13+
public NDArray(short value) : base(value) => NewEagerTensorHandle();
14+
public NDArray(int value) : base(value) => NewEagerTensorHandle();
15+
public NDArray(long value) : base(value) => NewEagerTensorHandle();
16+
public NDArray(float value) : base(value) => NewEagerTensorHandle();
17+
public NDArray(double value) : base(value) => NewEagerTensorHandle();
1818

19-
public NDArray(Array value, Shape? shape = null)
20-
: base(value, shape) { NewEagerTensorHandle(); }
19+
public NDArray(Array value, Shape? shape = null) : base(value, shape)
20+
=> NewEagerTensorHandle();
2121

22-
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
23-
: base(shape, dtype: dtype) { NewEagerTensorHandle(); }
22+
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) : base(shape, dtype: dtype)
23+
=> NewEagerTensorHandle();
2424

25-
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype)
26-
: base(bytes, shape, dtype) { NewEagerTensorHandle(); }
25+
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) : base(bytes, shape, dtype)
26+
=> NewEagerTensorHandle();
2727

28-
public NDArray(long[] value, Shape? shape = null)
29-
: base(value, shape) { NewEagerTensorHandle(); }
28+
public NDArray(long[] value, Shape? shape = null) : base(value, shape)
29+
=> NewEagerTensorHandle();
3030

31-
public NDArray(IntPtr address, Shape shape, TF_DataType dtype)
32-
: base(address, shape, dtype) { NewEagerTensorHandle(); }
31+
public NDArray(IntPtr address, Shape shape, TF_DataType dtype) : base(address, shape, dtype)
32+
=> NewEagerTensorHandle();
3333

3434
public NDArray(Tensor tensor, bool clone = false) : base(tensor.Handle, clone: clone)
3535
{

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
using System.Collections.Generic;
2020
using System.Linq;
2121
using System.Text;
22+
using Tensorflow.Util;
2223
using static Tensorflow.Binding;
2324

2425
namespace Tensorflow.NumPy
@@ -35,7 +36,10 @@ public ValueType GetValue(params int[] indices)
3536
public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype));
3637
public NDArray ravel() => throw new NotImplementedException("");
3738
public void shuffle(NDArray nd) => np.random.shuffle(nd);
38-
public Array ToMuliDimArray<T>() => throw new NotImplementedException("");
39+
40+
public unsafe Array ToMultiDimArray<T>() where T : unmanaged
41+
=> NDArrayConverter.ToMultiDimArray<T>(this);
42+
3943
public byte[] ToByteArray() => BufferToArray();
4044
public override string ToString() => NDArrayRender.ToString(this);
4145

src/TensorFlowNET.Keras/Saving/hdf5_format.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,19 +273,19 @@ private static void WriteDataset(long f, string name, Tensor data)
273273
switch (data.dtype)
274274
{
275275
case TF_DataType.TF_FLOAT:
276-
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
276+
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMultiDimArray<float>());
277277
break;
278278
case TF_DataType.TF_DOUBLE:
279-
Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMuliDimArray<double>());
279+
Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMultiDimArray<double>());
280280
break;
281281
case TF_DataType.TF_INT32:
282-
Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMuliDimArray<int>());
282+
Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMultiDimArray<int>());
283283
break;
284284
case TF_DataType.TF_INT64:
285-
Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMuliDimArray<long>());
285+
Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMultiDimArray<long>());
286286
break;
287287
default:
288-
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
288+
Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMultiDimArray<float>());
289289
break;
290290
}
291291
}

test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ public void array()
5050
AssetSequenceEqual(new[] { 1, 2, 3, 4, 5, 6 }, x.ToArray<int>());
5151
}
5252

53+
[TestMethod]
54+
public void to_multi_dim_array()
55+
{
56+
var x1 = np.arange(12);
57+
var y1 = x1.ToMultiDimArray<int>();
58+
AssetSequenceEqual((int[])y1, x1.ToArray<int>());
59+
60+
var x2 = np.arange(12).reshape((2, 6));
61+
var y2 = (int[,])x2.ToMultiDimArray<int>();
62+
Assert.AreEqual(x2[0, 5], y2[0, 5]);
63+
64+
var x3 = np.arange(12).reshape((2, 2, 3));
65+
var y3 = (int[,,])x3.ToMultiDimArray<int>();
66+
Assert.AreEqual(x3[0, 1, 2], y3[0, 1, 2]);
67+
}
68+
5369
[TestMethod]
5470
public void eye()
5571
{

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy