Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds the new System.Numerics.Tensors as an input/output type when using dotnet 8.0 and up. #23261

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="MSBuild.Sdk.Extras/3.0.22">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<!--- packaging properties -->
<OrtPackageId Condition="'$(OrtPackageId)' == ''">Microsoft.ML.OnnxRuntime</OrtPackageId>
Expand Down Expand Up @@ -184,6 +184,10 @@
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" />
</ItemGroup>

<ItemGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<PackageReference Include="System.Numerics.Tensors" Version="9.0.0" />
</ItemGroup>

<!-- debug output - makes finding/fixing any issues with the the conditions easy. -->
<Target Name="DumpValues" BeforeTargets="PreBuildEvent">
<Message Text="SolutionName='$(SolutionName)'" />
Expand Down
152 changes: 152 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
using System.Runtime.InteropServices;
using System.Text;

#if NET8_0_OR_GREATER
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
using SystemNumericsTensors = System.Numerics.Tensors;
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
#endif

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
Expand Down Expand Up @@ -205,6 +213,33 @@ public ReadOnlySpan<T> GetTensorDataAsSpan<T>() where T : unmanaged
return MemoryMarshal.Cast<byte, T>(byteSpan);
}

#if NET8_0_OR_GREATER
/// <summary>
/// Returns a ReadOnlyTensorSpan<typeparamref name="T"/> over tensor native buffer that
/// provides a read-only view.
///
/// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
/// To get memory descriptor use GetTensorMemoryInfo().
///
/// OrtValue must contain a non-string tensor.
/// The span is valid as long as the OrtValue instance is alive (not disposed).
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns>ReadOnlySpan<typeparamref name="T"/></returns>
/// <exception cref="OnnxRuntimeException"></exception>
[Experimental("SYSLIB5001")]
public SystemNumericsTensors.ReadOnlyTensorSpan<T> GetTensorDataAsTensorSpan<T>() where T : unmanaged
{
var byteSpan = GetTensorBufferRawData(typeof(T));

var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

return new SystemNumericsTensors.ReadOnlyTensorSpan<T>(typeSpan, nArray, []);
}
#endif

/// <summary>
/// Returns a Span<typeparamref name="T"/> over tensor native buffer.
/// This enables you to safely and efficiently modify the underlying
Expand All @@ -225,6 +260,32 @@ public Span<T> GetTensorMutableDataAsSpan<T>() where T : unmanaged
return MemoryMarshal.Cast<byte, T>(byteSpan);
}

#if NET8_0_OR_GREATER
/// <summary>
/// Returns a TensorSpan<typeparamref name="T"/> over tensor native buffer.
///
/// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
/// To get memory descriptor use GetTensorMemoryInfo().
///
/// OrtValue must contain a non-string tensor.
/// The span is valid as long as the OrtValue instance is alive (not disposed).
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns>ReadOnlySpan<typeparamref name="T"/></returns>
/// <exception cref="OnnxRuntimeException"></exception>
[Experimental("SYSLIB5001")]
public SystemNumericsTensors.TensorSpan<T> GetTensorMutableDataAsTensorSpan<T>() where T : unmanaged
{
var byteSpan = GetTensorBufferRawData(typeof(T));

var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

return new SystemNumericsTensors.TensorSpan<T>(typeSpan, nArray, []);
}
#endif

/// <summary>
/// Provides mutable raw native buffer access.
/// </summary>
Expand All @@ -234,6 +295,23 @@ public Span<byte> GetTensorMutableRawData()
return GetTensorBufferRawData(typeof(byte));
}

#if NET8_0_OR_GREATER
/// <summary>
/// Provides mutable raw native buffer access.
/// </summary>
/// <returns>TensorSpan over the native buffer bytes</returns>
[Experimental("SYSLIB5001")]
public SystemNumericsTensors.TensorSpan<byte> GetTensorSpanMutableRawData<T>() where T : unmanaged
{
var byteSpan = GetTensorBufferRawData(typeof(T));

var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

return new SystemNumericsTensors.TensorSpan<byte>(byteSpan, nArray, []);
}
#endif

/// <summary>
/// Fetch string tensor element buffer pointer at the specified index,
/// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance.
Expand Down Expand Up @@ -605,6 +683,80 @@ public static OrtValue CreateTensorValueFromMemory<T>(T[] data, long[] shape) wh
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, new Memory<T>(data), shape);
}

#if NET8_0_OR_GREATER
/// <summary>
/// This is a factory method creates a native Onnxruntime OrtValue containing a tensor on top of the existing tensor managed memory.
/// The method will attempt to pin managed memory so no copying occurs when data is passed down
/// to native code.
/// </summary>
/// <param name="value">Tensor object</param>
/// <param name="elementType">discovered tensor element type</param>
/// <returns>And instance of OrtValue constructed on top of the object</returns>
[Experimental("SYSLIB5001")]
public static OrtValue CreateTensorValueFromSystemNumericsTensorObject<T>(SystemNumericsTensors.Tensor<T> tensor) where T : unmanaged
{
if (!IsContiguousAndDense(tensor))
{
var newTensor = SystemNumericsTensors.Tensor.Create<T>(tensor.Lengths);
tensor.CopyTo(newTensor);
tensor = newTensor;
}
unsafe
{
var backingData = (T[])tensor.GetType().GetField("_values", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(tensor);
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
GCHandle handle = GCHandle.Alloc(backingData, GCHandleType.Pinned);
var memHandle = new MemoryHandle(Unsafe.AsPointer(ref tensor.GetPinnableReference()), handle);

try
{
IntPtr dataBufferPointer = IntPtr.Zero;
unsafe
michaelgsharp marked this conversation as resolved.
Show resolved Hide resolved
{
dataBufferPointer = (IntPtr)memHandle.Pointer;
}

var bufferLengthInBytes = tensor.FlattenedLength * sizeof(T);
long[] shape = Array.ConvertAll(tensor.Lengths.ToArray(), new Converter<nint, long>(x => (long)x));

var typeInfo = TensorBase.GetTypeInfo(typeof(T)) ??
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Tensor of type: {typeof(T)} is not supported");

NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorWithDataAsOrtValue(
OrtMemoryInfo.DefaultInstance.Pointer,
dataBufferPointer,
(UIntPtr)(bufferLengthInBytes),
shape,
(UIntPtr)tensor.Rank,
typeInfo.ElementType,
out IntPtr nativeValue));

return new OrtValue(nativeValue, memHandle);
}
catch (Exception)
{
memHandle.Dispose();
michaelgsharp marked this conversation as resolved.
Show resolved Hide resolved
throw;
}
}
}

[Experimental("SYSLIB5001")]
private static bool IsContiguousAndDense<T>(SystemNumericsTensors.Tensor<T> tensor) where T : unmanaged
{
// Right most dimension must be 1 for a dense tensor.
if (tensor.Strides[^1] != 1)
return false;

// For other dimensions, the stride must be equal to the product of the dimensions to the right.
for (int i = tensor.Rank - 2; i >= 0; i--)
{
if (tensor.Strides[i] != TensorPrimitives.Product(tensor.Lengths.Slice(i + 1, tensor.Lengths.Length - i - 1)))
return false;
}
return true;
}
#endif

/// <summary>
/// The factory API creates an OrtValue with memory allocated using the given allocator
/// according to the specified shape and element type. The memory will be released when OrtValue
Expand Down
Loading
Loading