Browse Source

Merge pull request #72 from auz34/master

Additional protection from possible stackoverflow
Sébastien Ros 11 years ago
parent
commit
66959b3a18

+ 147 - 1
Jint.Tests/Runtime/EngineTests.cs

@@ -604,7 +604,7 @@ namespace Jint.Tests.Runtime
                 () => new Engine(cfg => cfg.MaxStatements(100)).Execute("while(true);")
                 () => new Engine(cfg => cfg.MaxStatements(100)).Execute("while(true);")
             );
             );
         }
         }
-
+        
         [Fact]
         [Fact]
         public void ShouldThrowTimeout()
         public void ShouldThrowTimeout()
         {
         {
@@ -613,6 +613,152 @@ namespace Jint.Tests.Runtime
             );
             );
         }
         }
 
 
+
+        [Fact]
+        public void CanDiscardRecursion()
+        {
+            var script = @"var factorial = function(n) {
+                if (n>1) {
+                    return n * factorial(n - 1);
+                }
+            };
+
+            var result = factorial(500);
+            ";
+
+            Assert.Throws<RecursionDepthOverflowException>(
+                () => new Engine(cfg => cfg.LimitRecursion()).Execute(script)
+            );
+        }
+
+        [Fact]
+        public void ShouldDiscardHiddenRecursion()
+        {
+            var script = @"var renamedFunc;
+            var exec = function(callback) {
+                renamedFunc = callback;
+                callback();
+            };
+
+            var result = exec(function() {
+                renamedFunc();
+            });
+            ";
+
+            Assert.Throws<RecursionDepthOverflowException>(
+                () => new Engine(cfg => cfg.LimitRecursion()).Execute(script)
+            );
+        }
+
+        [Fact]
+        public void ShouldRecognizeAndDiscardChainedRecursion()
+        {
+            var script = @" var funcRoot, funcA, funcB, funcC, funcD;
+
+            var funcRoot = function() {
+                funcA();
+            };
+ 
+            var funcA = function() {
+                funcB();
+            };
+
+            var funcB = function() {
+                funcC();
+            };
+
+            var funcC = function() {
+                funcD();
+            };
+
+            var funcD = function() {
+                funcRoot();
+            };
+
+            funcRoot();
+            ";
+
+            Assert.Throws<RecursionDepthOverflowException>(
+                () => new Engine(cfg => cfg.LimitRecursion()).Execute(script)
+            );
+        }
+
+        [Fact]
+        public void ShouldProvideCallChainWhenDiscardRecursion()
+        {
+            var script = @" var funcRoot, funcA, funcB, funcC, funcD;
+
+            var funcRoot = function() {
+                funcA();
+            };
+ 
+            var funcA = function() {
+                funcB();
+            };
+
+            var funcB = function() {
+                funcC();
+            };
+
+            var funcC = function() {
+                funcD();
+            };
+
+            var funcD = function() {
+                funcRoot();
+            };
+
+            funcRoot();
+            ";
+
+            RecursionDepthOverflowException exception = null;
+
+            try
+            {
+                new Engine(cfg => cfg.LimitRecursion()).Execute(script);
+            }
+            catch (RecursionDepthOverflowException ex)
+            {
+                exception = ex;
+            }
+
+            Assert.NotNull(exception);
+            Assert.Equal("funcRoot->funcA->funcB->funcC->funcD", exception.CallChain);
+            Assert.Equal("funcRoot", exception.CallExpressionReference);
+        }
+
+        [Fact]
+        public void ShouldAllowShallowRecursion()
+        {
+            var script = @"var factorial = function(n) {
+                if (n>1) {
+                    return n * factorial(n - 1);
+                }
+            };
+
+            var result = factorial(8);
+            ";
+
+            new Engine(cfg => cfg.LimitRecursion(20)).Execute(script);
+        }
+
+        [Fact]
+        public void ShouldDiscardDeepRecursion()
+        {
+            var script = @"var factorial = function(n) {
+                if (n>1) {
+                    return n * factorial(n - 1);
+                }
+            };
+
+            var result = factorial(38);
+            ";
+
+            Assert.Throws<RecursionDepthOverflowException>(
+                () => new Engine(cfg => cfg.LimitRecursion(20)).Execute(script)
+            );
+        }
+
         [Fact]
         [Fact]
         public void ShouldConvertDoubleToStringWithoutLosingPrecision()
         public void ShouldConvertDoubleToStringWithoutLosingPrecision()
         {
         {

+ 15 - 2
Jint/Engine.cs

@@ -25,6 +25,8 @@ using Jint.Runtime.References;
 
 
 namespace Jint
 namespace Jint
 {
 {
+    using Jint.Runtime.CallStack;
+
     public class Engine
     public class Engine
     {
     {
         private readonly ExpressionInterpreter _expressions;
         private readonly ExpressionInterpreter _expressions;
@@ -38,7 +40,9 @@ namespace Jint
         public ITypeConverter ClrTypeConverter;
         public ITypeConverter ClrTypeConverter;
 
 
         // cache of types used when resolving CLR type names
         // cache of types used when resolving CLR type names
-        internal Dictionary<string, Type> TypeCache = new Dictionary<string, Type>(); 
+        internal Dictionary<string, Type> TypeCache = new Dictionary<string, Type>();
+
+        internal JintCallStack CallStack = new JintCallStack();
 
 
         public Engine() : this(null)
         public Engine() : this(null)
         {
         {
@@ -220,13 +224,21 @@ namespace Jint
         {
         {
             _statementsCount = 0;
             _statementsCount = 0;
         }
         }
-
+        
         public void ResetTimeoutTicks()
         public void ResetTimeoutTicks()
         {
         {
             var timeoutIntervalTicks = Options.GetTimeoutInterval().Ticks;
             var timeoutIntervalTicks = Options.GetTimeoutInterval().Ticks;
             _timeoutTicks = timeoutIntervalTicks > 0 ? DateTime.UtcNow.Ticks + timeoutIntervalTicks : 0;
             _timeoutTicks = timeoutIntervalTicks > 0 ? DateTime.UtcNow.Ticks + timeoutIntervalTicks : 0;
         }
         }
 
 
+        /// <summary>
+        /// Initializes list of references of called functions
+        /// </summary>
+        public void ResetCallStack()
+        {
+            CallStack.Clear();
+        }
+
         public Engine Execute(string source)
         public Engine Execute(string source)
         {
         {
             var parser = new JavaScriptParser();
             var parser = new JavaScriptParser();
@@ -244,6 +256,7 @@ namespace Jint
             ResetStatementsCount();
             ResetStatementsCount();
             ResetTimeoutTicks();
             ResetTimeoutTicks();
             ResetLastStatement();
             ResetLastStatement();
+            ResetCallStack();
 
 
             using (new StrictModeScope(Options.IsStrict() || program.Strict))
             using (new StrictModeScope(Options.IsStrict() || program.Strict))
             {
             {

+ 4 - 0
Jint/Jint.csproj

@@ -159,6 +159,9 @@
     <Compile Include="Parser\Token.cs" />
     <Compile Include="Parser\Token.cs" />
     <Compile Include="Properties\AssemblyInfo.cs" />
     <Compile Include="Properties\AssemblyInfo.cs" />
     <Compile Include="Runtime\Arguments.cs" />
     <Compile Include="Runtime\Arguments.cs" />
+    <Compile Include="Runtime\CallStack\JintCallStack.cs" />
+    <Compile Include="Runtime\CallStack\CallStackElementComparer.cs" />
+    <Compile Include="Runtime\CallStack\CallStackElement.cs" />
     <Compile Include="Runtime\Completion.cs" />
     <Compile Include="Runtime\Completion.cs" />
     <Compile Include="Runtime\Descriptors\PropertyDescriptor.cs" />
     <Compile Include="Runtime\Descriptors\PropertyDescriptor.cs" />
     <Compile Include="Runtime\Descriptors\Specialized\FieldInfoDescriptor.cs" />
     <Compile Include="Runtime\Descriptors\Specialized\FieldInfoDescriptor.cs" />
@@ -186,6 +189,7 @@
     <Compile Include="Runtime\Interop\TypeReference.cs" />
     <Compile Include="Runtime\Interop\TypeReference.cs" />
     <Compile Include="Runtime\Interop\TypeReferencePrototype.cs" />
     <Compile Include="Runtime\Interop\TypeReferencePrototype.cs" />
     <Compile Include="Runtime\JavaScriptException.cs" />
     <Compile Include="Runtime\JavaScriptException.cs" />
+    <Compile Include="Runtime\RecursionDepthOverflowException.cs" />
     <Compile Include="Runtime\References\Reference.cs" />
     <Compile Include="Runtime\References\Reference.cs" />
     <Compile Include="Runtime\StatementInterpreter.cs" />
     <Compile Include="Runtime\StatementInterpreter.cs" />
     <Compile Include="Runtime\StatementsCountOverflowException.cs" />
     <Compile Include="Runtime\StatementsCountOverflowException.cs" />

+ 22 - 1
Jint/Options.cs

@@ -16,6 +16,7 @@ namespace Jint
         private bool _allowClr;
         private bool _allowClr;
         private readonly List<IObjectConverter> _objectConverters = new List<IObjectConverter>();
         private readonly List<IObjectConverter> _objectConverters = new List<IObjectConverter>();
         private int _maxStatements;
         private int _maxStatements;
+        private int _maxRecursionDepth = -1; 
         private TimeSpan _timeoutInterval;
         private TimeSpan _timeoutInterval;
         private CultureInfo _culture = CultureInfo.CurrentCulture;
         private CultureInfo _culture = CultureInfo.CurrentCulture;
         private TimeZoneInfo _localTimeZone = TimeZoneInfo.Local;
         private TimeZoneInfo _localTimeZone = TimeZoneInfo.Local;
@@ -78,13 +79,28 @@ namespace Jint
             _maxStatements = maxStatements;
             _maxStatements = maxStatements;
             return this;
             return this;
         }
         }
-
+        
         public Options TimeoutInterval(TimeSpan timeoutInterval)
         public Options TimeoutInterval(TimeSpan timeoutInterval)
         {
         {
             _timeoutInterval = timeoutInterval;
             _timeoutInterval = timeoutInterval;
             return this;
             return this;
         }
         }
 
 
+        /// <summary>
+        /// Sets maximum allowed depth of recursion.
+        /// </summary>
+        /// <param name="maxRecursionDepth">
+        /// The allowed depth.
+        /// a) In case max depth is zero no recursion is allowed.
+        /// b) In case max depth is equal to n it means that in one scope function can be called no more than n times.
+        /// </param>
+        /// <returns>Options instance for fluent syntax</returns>
+        public Options LimitRecursion(int maxRecursionDepth = 0)
+        {
+            _maxRecursionDepth = maxRecursionDepth;
+            return this;
+        }
+
         public Options Culture(CultureInfo cultureInfo)
         public Options Culture(CultureInfo cultureInfo)
         {
         {
             _culture = cultureInfo;
             _culture = cultureInfo;
@@ -132,6 +148,11 @@ namespace Jint
             return _maxStatements;
             return _maxStatements;
         }
         }
 
 
+        internal int GetMaxRecursionDepth()
+        {
+            return _maxRecursionDepth;
+        }
+
         internal TimeSpan GetTimeoutInterval()
         internal TimeSpan GetTimeoutInterval()
         {
         {
             return _timeoutInterval;
             return _timeoutInterval;

+ 26 - 0
Jint/Runtime/CallStack/CallStackElement.cs

@@ -0,0 +1,26 @@
+namespace Jint.Runtime
+{
+    using Jint.Native;
+    using Jint.Parser.Ast;
+
+    public class CallStackElement
+    {
+        private string _shortDescription;
+
+        public CallStackElement(CallExpression callExpression, JsValue function, string shortDescription)
+        {
+            _shortDescription = shortDescription;
+            CallExpression = callExpression;
+            Function = function;
+        }
+
+        public CallExpression CallExpression { get; private set; }
+
+        public JsValue Function { get; private set; }
+
+        public override string ToString()
+        {
+            return _shortDescription;
+        }
+    }
+}

+ 17 - 0
Jint/Runtime/CallStack/CallStackElementComparer.cs

@@ -0,0 +1,17 @@
+namespace Jint.Runtime.CallStack
+{
+    using System.Collections.Generic;
+
+    public class CallStackElementComparer: IEqualityComparer<CallStackElement>
+    {
+        public bool Equals(CallStackElement x, CallStackElement y)
+        {
+            return x.Function == y.Function;
+        }
+
+        public int GetHashCode(CallStackElement obj)
+        {
+            return obj.Function.GetHashCode();
+        }
+    }
+}

+ 53 - 0
Jint/Runtime/CallStack/JintCallStack.cs

@@ -0,0 +1,53 @@
+namespace Jint.Runtime.CallStack
+{
+    using System.Collections.Generic;
+    using System.Linq;
+
+    public class JintCallStack
+    {
+        private Stack<CallStackElement> _stack = new Stack<CallStackElement>();
+
+        private Dictionary<CallStackElement, int> _statistics =
+            new Dictionary<CallStackElement, int>(new CallStackElementComparer());
+
+        public int Push(CallStackElement item)
+        {
+            _stack.Push(item);
+            if (_statistics.ContainsKey(item))
+            {
+                return ++_statistics[item];
+            }
+            else
+            {
+                _statistics.Add(item, 0);
+                return 0;
+            }
+        }
+
+        public CallStackElement Pop()
+        {
+            var item = _stack.Pop();
+            if (_statistics[item] == 0)
+            {
+                _statistics.Remove(item);
+            }
+            else
+            {
+                _statistics[item]--;
+            }
+
+            return item;
+        }
+
+        public void Clear()
+        {
+            _stack.Clear();
+            _statistics.Clear();
+        }
+
+        public override string ToString()
+        {
+            return string.Join("->", _stack.Select(cse => cse.ToString()).Reverse());
+        }
+    }
+}

+ 23 - 2
Jint/Runtime/ExpressionIntepreter.cs

@@ -795,9 +795,23 @@ namespace Jint.Runtime
             var arguments = callExpression.Arguments.Select(EvaluateExpression).Select(_engine.GetValue).ToArray();
             var arguments = callExpression.Arguments.Select(EvaluateExpression).Select(_engine.GetValue).ToArray();
 
 
             var func = _engine.GetValue(callee);
             var func = _engine.GetValue(callee);
-
+            
             var r = callee as Reference;
             var r = callee as Reference;
 
 
+            var isRecursionHandled = _engine.Options.GetMaxRecursionDepth() >= 0;
+            if (isRecursionHandled)
+            {
+                var stackItem = new CallStackElement(callExpression, func, r != null ? r.GetReferencedName() : "anonymous function");
+
+                var recursionDepth = _engine.CallStack.Push(stackItem);
+
+                if (recursionDepth > _engine.Options.GetMaxRecursionDepth())
+                {
+                    _engine.CallStack.Pop();
+                    throw new RecursionDepthOverflowException(_engine.CallStack, stackItem.ToString());
+                }
+            }
+
             if (func == Undefined.Instance)
             if (func == Undefined.Instance)
             {
             {
                 throw new JavaScriptException(_engine.TypeError, r == null ? "" : string.Format("Object has no method '{0}'", (callee as Reference).GetReferencedName()));
                 throw new JavaScriptException(_engine.TypeError, r == null ? "" : string.Format("Object has no method '{0}'", (callee as Reference).GetReferencedName()));
@@ -837,7 +851,14 @@ namespace Jint.Runtime
                 return ((EvalFunctionInstance) callable).Call(thisObject, arguments, true);
                 return ((EvalFunctionInstance) callable).Call(thisObject, arguments, true);
             }
             }
             
             
-            return callable.Call(thisObject, arguments);
+            var result = callable.Call(thisObject, arguments);
+
+            if (isRecursionHandled)
+            {
+                _engine.CallStack.Pop();
+            }
+
+            return result;
         }
         }
 
 
         public JsValue EvaluateSequenceExpression(SequenceExpression sequenceExpression)
         public JsValue EvaluateSequenceExpression(SequenceExpression sequenceExpression)

+ 3 - 1
Jint/Runtime/Interop/MethodInfoFunctionInstance.cs

@@ -6,6 +6,8 @@ using Jint.Native.Function;
 
 
 namespace Jint.Runtime.Interop
 namespace Jint.Runtime.Interop
 {
 {
+    using System;
+
     public sealed class MethodInfoFunctionInstance : FunctionInstance
     public sealed class MethodInfoFunctionInstance : FunctionInstance
     {
     {
         private readonly MethodInfo[] _methods;
         private readonly MethodInfo[] _methods;
@@ -59,7 +61,7 @@ namespace Jint.Runtime.Interop
 
 
                     return result;
                     return result;
                 }
                 }
-                catch
+                catch 
                 {
                 {
                     // ignore method
                     // ignore method
                 }
                 }

+ 27 - 0
Jint/Runtime/RecursionDepthOverflowException.cs

@@ -0,0 +1,27 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+using Jint.Native;
+using Jint.Parser.Ast;
+
+namespace Jint.Runtime
+{
+    using Jint.Runtime.CallStack;
+
+    public class RecursionDepthOverflowException : Exception
+    {
+        public string CallChain { get; private set; }
+
+        public string CallExpressionReference { get; private set; }
+
+        public RecursionDepthOverflowException(JintCallStack currentStack, string currentExpressionReference)
+            : base("The recursion is forbidden by script host.")
+        {
+            CallExpressionReference = currentExpressionReference;
+
+            CallChain = currentStack.ToString();
+        }
+    }
+    
+}