Skip to content

Commit 32051cb

Browse files
BadSingletonfilmor
andauthored
Expose serialization api (#2336)
* Expose an API for users to specify their own formatter Adds post-serialization and pre-deserialization hooks for additional customization. * Add API for capsuling data when serializing * Add NoopFormatter and fall back to it if BinaryFormatter is not available --------- Co-authored-by: Benedikt Reinartz <filmor@gmail.com>
1 parent 195cde6 commit 32051cb

File tree

5 files changed

+269
-6
lines changed

5 files changed

+269
-6
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].
1515
to compare with primitive .NET types like `long`.
1616

1717
### Changed
18+
- Added a `FormatterFactory` member in `RuntimeData` to create formatters with parameters. For compatibility, the `FormatterType` member is still present and takes precedence when both `FormatterFactory` and `FormatterType` are defined.
19+
- Added post-serialization and pre-deserialization callbacks to extend the (de)serialization process
20+
- Added an API to stash serialized data on Python capsules
1821

1922
### Fixed
2023

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.IO;
3+
using System.Runtime.Serialization;
4+
5+
namespace Python.Runtime;
6+
7+
/// <summary>
/// An <see cref="IFormatter"/> that writes nothing when asked to serialize
/// and cannot deserialize anything.
/// </summary>
public class NoopFormatter : IFormatter
{
    /// <summary>Gets or sets the binder. This formatter only stores the value.</summary>
    public SerializationBinder? Binder { get; set; }

    /// <summary>Gets or sets the streaming context. This formatter only stores the value.</summary>
    public StreamingContext Context { get; set; }

    /// <summary>Gets or sets the surrogate selector. This formatter only stores the value.</summary>
    public ISurrogateSelector? SurrogateSelector { get; set; }

    /// <summary>Not supported by this formatter.</summary>
    /// <param name="s">Ignored.</param>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public object Deserialize(Stream s)
    {
        throw new NotImplementedException();
    }

    /// <summary>Does nothing; neither argument is touched.</summary>
    /// <param name="s">Ignored.</param>
    /// <param name="o">Ignored.</param>
    public void Serialize(Stream s, object o)
    {
    }
}

src/runtime/StateSerialization/RuntimeData.cs

Lines changed: 132 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
using System;
2-
using System.Collections;
32
using System.Collections.Generic;
4-
using System.Collections.ObjectModel;
53
using System.Diagnostics;
64
using System.IO;
75
using System.Linq;
@@ -17,7 +15,34 @@ namespace Python.Runtime
1715
{
1816
public static class RuntimeData
1917
{
20-
private static Type? _formatterType;
18+
19+
/// <summary>
/// The factory used until <see cref="FormatterFactory"/> is assigned.
/// It tries to construct a <see cref="BinaryFormatter"/>; if that
/// construction throws, a <see cref="NoopFormatter"/> is returned instead.
/// </summary>
public static readonly Func<IFormatter> DefaultFormatterFactory = () =>
{
    IFormatter formatter;
    try
    {
        formatter = new BinaryFormatter();
    }
    catch
    {
        // Deliberately broad: any failure to create the BinaryFormatter
        // selects the no-op fallback.
        formatter = new NoopFormatter();
    }
    return formatter;
};
30+
31+
// Backing store for FormatterFactory; starts out as DefaultFormatterFactory.
private static Func<IFormatter> _formatterFactory { get; set; } = DefaultFormatterFactory;

/// <summary>
/// Gets or sets the factory that <c>CreateFormatter</c> invokes when
/// <see cref="FormatterType"/> is not set.
/// </summary>
/// <exception cref="ArgumentNullException">The assigned value is <c>null</c>.</exception>
public static Func<IFormatter> FormatterFactory
{
    get { return _formatterFactory; }
    set { _formatterFactory = value ?? throw new ArgumentNullException(nameof(value)); }
}
44+
45+
private static Type? _formatterType = null;
2146
public static Type? FormatterType
2247
{
2348
get => _formatterType;
@@ -31,6 +56,14 @@ public static Type? FormatterType
3156
}
3257
}
3358

59+
/// <summary>
/// Optional callback invoked as the last step of <c>Stash</c>, after the
/// <c>clr_data</c> capsule has been stored on the `sys` module object.
/// </summary>
public static Action? PostStashHook { get; set; }

/// <summary>
/// Optional callback invoked as the first step of restoring the runtime
/// data, before the <c>clr_data</c> capsule is looked up.
/// </summary>
public static Action? PreRestoreHook { get; set; }
3467
public static ICLRObjectStorer? WrappersStorer { get; set; }
3568

3669
/// <summary>
@@ -74,6 +107,7 @@ internal static void Stash()
74107
using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
75108
int res = PySys_SetObject("clr_data", capsule.BorrowOrThrow());
76109
PythonException.ThrowIfIsNotZero(res);
110+
PostStashHook?.Invoke();
77111
}
78112

79113
internal static void RestoreRuntimeData()
@@ -90,6 +124,7 @@ internal static void RestoreRuntimeData()
90124

91125
private static void RestoreRuntimeDataImpl()
92126
{
127+
PreRestoreHook?.Invoke();
93128
BorrowedReference capsule = PySys_GetObject("clr_data");
94129
if (capsule.IsNull)
95130
{
@@ -250,11 +285,102 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage)
250285
}
251286
}
252287

288+
static readonly string serialization_key_namepsace = "pythonnet_serialization_";

/// <summary>
/// Frees the unmanaged buffer held by a serialization capsule and removes
/// the capsule from the `sys` module object.
/// </summary>
/// <remarks>
/// The serialization data must have been set with <code>StashSerializationData</code>,
/// which stores the capsule under <paramref name="key"/> exactly as given.
/// For compatibility the name prefixed with <c>pythonnet_serialization_</c>
/// (the only name this method used to look up) is tried first, then the
/// unprefixed name. Nothing happens when neither name exists.
/// </remarks>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
public static void FreeSerializationData(string key)
{
    // Legacy lookup: the prefixed name.
    string sysName = serialization_key_namepsace + key;
    BorrowedReference oldCapsule = PySys_GetObject(sysName);
    if (oldCapsule.IsNull)
    {
        // The name StashSerializationData / GetSerializationData actually use.
        sysName = key;
        oldCapsule = PySys_GetObject(sysName);
        if (oldCapsule.IsNull)
        {
            // nothing to free.
            return;
        }
    }

    IntPtr oldData = PyCapsule_GetPointer(oldCapsule, IntPtr.Zero);
    if (oldData == IntPtr.Zero)
    {
        // Same handling as GetSerializationData: the PyCapsule API returns
        // NULL on error (e.g. the object is not a capsule). Bail out instead
        // of deleting an unrelated attribute from `sys`.
        PythonException.ThrowIfIsNull(null);
    }

    Marshal.FreeHGlobal(oldData);
    // NOTE(review): GetSerializationData notes that NULL cannot be stored as a
    // capsule's value, so this call presumably fails and may leave a Python
    // error set; kept as in the original - confirm the intended behaviour.
    PyCapsule_SetPointer(oldCapsule, IntPtr.Zero);
    PySys_SetObject(sysName, null);
}
308+
309+
/// <summary>
/// Stores the data in the <paramref name="stream"/> argument in a Python capsule and stores
/// the capsule on the `sys` module object with the name <paramref name="key"/>.
/// </summary>
/// <remarks>
/// No checks on pre-existing names on the `sys` module object are made.
/// The capsule points at an unmanaged buffer laid out as a pointer-sized
/// length followed by the payload bytes.
/// </remarks>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
/// <param name="stream">A MemoryStream that contains the data to be placed in the capsule</param>
/// <exception cref="ArgumentNullException"><paramref name="stream"/> is <c>null</c>.</exception>
/// <exception cref="NotImplementedException">The buffer of <paramref name="stream"/> is not exposable.</exception>
public static void StashSerializationData(string key, MemoryStream stream)
{
    if (stream == null)
    {
        throw new ArgumentNullException(nameof(stream));
    }

    if (!stream.TryGetBuffer(out var data))
    {
        throw new NotImplementedException($"{nameof(stream)} must be exposable");
    }

    IntPtr mem = Marshal.AllocHGlobal(IntPtr.Size + data.Count);
    try
    {
        // store the length of the buffer first
        Marshal.WriteIntPtr(mem, (IntPtr)data.Count);
        Marshal.Copy(data.Array, data.Offset, mem + IntPtr.Size, data.Count);

        using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
        int res = PySys_SetObject(key, capsule.BorrowOrThrow());
        PythonException.ThrowIfIsNotZero(res);
    }
    catch
    {
        // The capsule was not stored, so nothing else will release the buffer.
        Marshal.FreeHGlobal(mem);
        // Propagate the failure: swallowing it would make the caller believe
        // the data had been stashed.
        throw;
    }
}
345+
346+
static byte[] emptyBuffer = new byte[0];

/// <summary>
/// Retrieves the data previously stored on a Python capsule.
/// Throws if the object corresponding to the <paramref name="key"/> parameter
/// on the `sys` module object is not a capsule.
/// </summary>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
/// <returns>A read-only MemoryStream over a copy of the previously saved serialization data.
/// The stream is empty if no name matches the key.</returns>
public static MemoryStream GetSerializationData(string key)
{
    BorrowedReference capsule = PySys_GetObject(key);
    if (!capsule.IsNull)
    {
        IntPtr payload = PyCapsule_GetPointer(capsule, IntPtr.Zero);
        if (payload == IntPtr.Zero)
        {
            // The PyCapsule API returns NULL on error; NULL cannot be stored
            // as a capsule's value
            PythonException.ThrowIfIsNull(null);
        }

        // The buffer starts with its pointer-sized length, followed by the bytes.
        int length = (int)Marshal.ReadIntPtr(payload);
        var copy = new byte[length];
        Marshal.Copy(payload + IntPtr.Size, copy, 0, length);
        return new MemoryStream(copy, writable: false);
    }

    // nothing stored under this name.
    return new MemoryStream(emptyBuffer, writable: false);
}
375+
253376
/// <summary>
/// Creates the formatter used for (de)serialization. When
/// <see cref="FormatterType"/> is set it takes precedence and is
/// instantiated through <see cref="Activator.CreateInstance(Type)"/>;
/// otherwise <see cref="FormatterFactory"/> is invoked.
/// </summary>
/// <returns>A new formatter instance.</returns>
internal static IFormatter CreateFormatter()
{
    return FormatterType != null
        ? (IFormatter)Activator.CreateInstance(FormatterType)
        : FormatterFactory();
}
259385
}
260386
}

tests/domain_tests/TestRunner.cs

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,66 @@ import System
11321132
11331133
",
11341134
},
1135+
new TestCase
1136+
{
1137+
Name = "test_serialize_unserializable_object",
1138+
DotNetBefore = @"
1139+
namespace TestNamespace
1140+
{
1141+
public class NotSerializableTextWriter : System.IO.TextWriter
1142+
{
1143+
override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
1144+
}
1145+
[System.Serializable]
1146+
public static class SerializableWriter
1147+
{
1148+
private static System.IO.TextWriter _writer = null;
1149+
public static System.IO.TextWriter Writer {get { return _writer; }}
1150+
public static void CreateInternalWriter()
1151+
{
1152+
_writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
1153+
}
1154+
}
1155+
}
1156+
",
1157+
DotNetAfter = @"
1158+
namespace TestNamespace
1159+
{
1160+
public class NotSerializableTextWriter : System.IO.TextWriter
1161+
{
1162+
override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
1163+
}
1164+
[System.Serializable]
1165+
public static class SerializableWriter
1166+
{
1167+
private static System.IO.TextWriter _writer = null;
1168+
public static System.IO.TextWriter Writer {get { return _writer; }}
1169+
public static void CreateInternalWriter()
1170+
{
1171+
_writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
1172+
}
1173+
}
1174+
}
1175+
",
1176+
PythonCode = @"
1177+
import sys
1178+
1179+
def before_reload():
1180+
import clr
1181+
import System
1182+
clr.AddReference('DomainTests')
1183+
import TestNamespace
1184+
TestNamespace.SerializableWriter.CreateInternalWriter();
1185+
sys.__obj = TestNamespace.SerializableWriter.Writer
1186+
sys.__obj.WriteLine('test')
1187+
1188+
def after_reload():
1189+
import clr
1190+
import System
1191+
sys.__obj.WriteLine('test')
1192+
1193+
",
1194+
}
11351195
};
11361196

11371197
/// <summary>
@@ -1142,7 +1202,59 @@ import System
11421202
const string CaseRunnerTemplate = @"
11431203
using System;
11441204
using System.IO;
1205+
using System.Runtime.Serialization;
1206+
using System.Runtime.Serialization.Formatters.Binary;
11451207
using Python.Runtime;
1208+
1209+
namespace Serialization
1210+
{{
1211+
// Classes in this namespace is mostly useful for test_serialize_unserializable_object
1212+
class NotSerializableSerializer : ISerializationSurrogate
1213+
{{
1214+
public NotSerializableSerializer()
1215+
{{
1216+
}}
1217+
public void GetObjectData(object obj, SerializationInfo info, StreamingContext context)
1218+
{{
1219+
info.AddValue(""notSerialized_tp"", obj.GetType());
1220+
}}
1221+
public object SetObjectData(object obj, SerializationInfo info, StreamingContext context, ISurrogateSelector selector)
1222+
{{
1223+
if (info == null)
1224+
{{
1225+
return null;
1226+
}}
1227+
Type typeObj = info.GetValue(""notSerialized_tp"", typeof(Type)) as Type;
1228+
if (typeObj == null)
1229+
{{
1230+
return null;
1231+
}}
1232+
1233+
obj = Activator.CreateInstance(typeObj);
1234+
return obj;
1235+
}}
1236+
}}
1237+
class NonSerializableSelector : SurrogateSelector
1238+
{{
1239+
public override ISerializationSurrogate GetSurrogate(Type type, StreamingContext context, out ISurrogateSelector selector)
1240+
{{
1241+
if (type == null)
1242+
{{
1243+
throw new ArgumentNullException();
1244+
}}
1245+
selector = (ISurrogateSelector)this;
1246+
if (type.IsSerializable)
1247+
{{
1248+
return null; // use whichever default
1249+
}}
1250+
else
1251+
{{
1252+
return (ISerializationSurrogate)(new NotSerializableSerializer());
1253+
}}
1254+
}}
1255+
}}
1256+
}}
1257+
11461258
namespace CaseRunner
11471259
{{
11481260
class CaseRunner
@@ -1151,6 +1263,11 @@ public static int Main()
11511263
{{
11521264
try
11531265
{{
1266+
RuntimeData.FormatterFactory = () =>
1267+
{{
1268+
return new BinaryFormatter(){{SurrogateSelector = new Serialization.NonSerializableSelector()}};
1269+
}};
1270+
11541271
PythonEngine.Initialize();
11551272
using (Py.GIL())
11561273
{{

tests/domain_tests/test_domain_reload.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,6 @@ def test_nested_type():
8888

8989
def test_import_after_reload():
    _run_test("import_after_reload")


# Needs its own name: a second `test_import_after_reload` would rebind the
# function above, so the import-after-reload case would never be run.
def test_serialize_unserializable_object():
    _run_test("test_serialize_unserializable_object")

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy