configurationengine/source/cone/public/rules.py
changeset 3 e7e0ae78773e
parent 0 2e8eeb919028
--- a/configurationengine/source/cone/public/rules.py	Fri Mar 12 08:30:17 2010 +0200
+++ b/configurationengine/source/cone/public/rules.py	Tue Aug 10 14:29:28 2010 +0300
@@ -17,10 +17,12 @@
 import operator as ops
 import logging
 import tokenize
-from token import ENDMARKER, NAME, ERRORTOKEN
+import re
+from token import ENDMARKER, NAME, ERRORTOKEN, OP
 import StringIO
 
-from cone.public import container
+from cone.public import container, exceptions, utils
+import types
 
 RELATIONS = {}
 
@@ -29,14 +31,20 @@
     caller = inspect.getouterframes(inspect.currentframe())[1][3]
     raise NotImplementedError(caller + ' needs to be implemented')
 
+REF_DOLLAR = '$'
+REF_START_BRACE = '{'
+REF_END_BRACE = '}'
+REF_REGEX = re.compile('(?P<ref>\${[\w\.\*]*})', re.UNICODE)
+
 def get_tokens(tokenstr):
     result = []
     tokens = []
     tokenstr = tokenstr.replace('\r', '')
     name_buffer = [] # Temp buffer for reading name tokens
     last_epos = None
+    
+    ref_start = False
     for toknum, tokval, spos, epos, _  in tokenize.generate_tokens(StringIO.StringIO(unicode(tokenstr)).readline):
-        #print "toknum: %r, tokval: %r, spos: %r, epos: %r" % (toknum, tokval, spos, epos)
         val = tokval.strip('\r\n\t ')
         
         if toknum == ENDMARKER and name_buffer:
@@ -46,15 +54,28 @@
         # since its value is empty)
         if val == '': continue
         
+        # Handle the references here, ${ref} is the format
+        if toknum in (OP, ERRORTOKEN) and\
+                tokval in (REF_DOLLAR, REF_START_BRACE, REF_END_BRACE) or\
+                ref_start:
+            if tokval == REF_DOLLAR:
+                ref_start = True
+            elif tokval == REF_END_BRACE:
+                ref_start = False
+            if name_buffer and spos[1] != last_epos[1]:
+                tokens.append(''.join(name_buffer))
+                name_buffer = []
+            name_buffer.append(tokval)
+            last_epos = epos
         # Put NAME, and ERRORTOKEN tokens through the temp
         # buffer
-        if toknum in (NAME, ERRORTOKEN):
+        elif toknum in (NAME, ERRORTOKEN):
             # If this and the previous token in the temp buffer are not adjacent,
             # they belong to separate tokens
             if name_buffer and spos[1] != last_epos[1]:
                 tokens.append(''.join(name_buffer))
                 name_buffer = []
-            
+
             name_buffer.append(val)
             last_epos = epos
         # Other tokens can just go directly to the token list
@@ -89,6 +110,8 @@
     relation_name = "RelationBase"
     def __init__(self, data, left, right):
         self.description = ""
+        self.ref = None
+        self.lineno = None
         self.data = data or container.DataContainer()
         self.left = left
         self.right = right
@@ -98,6 +121,9 @@
         @return: A string presentation of the relation object
         """
         return "%s %s %s" % (self.left,self.relation_name,self.right)
+    
+    def __repr__(self):
+        return "%s(ref=%r, lineno=%r)" % (self.__class__.__name__, self.ref, self.lineno)
 
     def get_name(self):
         """
@@ -111,7 +137,7 @@
         """
         return self.description
         
-    def execute(self):
+    def execute(self, context=None):
         """
         Execute the relation object.
         """
@@ -161,12 +187,12 @@
 class RelationContainerImpl(RelationContainer):
     """ Base implementation for RelationContainer to use in ConE rules
     """
-    def execute(self):
+    def execute(self, context=None):
         ret = True
         i = 0
         for relation in self:
             i += 1
-            r = relation.execute()
+            r = relation.execute(context)
             ret = ret and r
         return ret
 
@@ -225,8 +251,9 @@
         super(BaseRelation, self).__init__(data, left, right)
         self.interpreter = ASTInterpreter(context=self.context)
 
-    def execute(self):
+    def execute(self, context=None):
         """
+        @param context: The context for execution can be given as a parameter. 
         @return Returns error dictionary
 
         In the client code proper way to check if the rule applies:
@@ -236,7 +263,7 @@
         """
         # logger.debug("Interpreter context %s" % self.interpreter.context)
         self.interpreter.create_ast('%s %s %s' % (self.left, self.KEY, self.right))
-        ret = self.interpreter.eval()
+        ret = self.interpreter.eval(context, relation=self)
         return ret
 
     def get_keys(self):
@@ -255,8 +282,44 @@
         return bool(self.interpreter.errors)
 
     def get_refs(self):
-        return (ASTInterpreter.extract_refs(self.left), ASTInterpreter.extract_refs(self.right))
+        """
+        Get a list of left side references and right side references.
+        @return: left refs
+        """
+        try:
+            refs = []
+            tempast = ASTInterpreter()
+            tempast.create_ast("%s" % self.left)
+            for exp in tempast.expression_list:
+                refs += exp.get_refs()
+        except Exception, e:
+            utils.log_exception(logging.getLogger('cone.rules'), "Exception in get_refs() of relation %r: %s" % (self, e))
+            return []
+        return refs
 
+    def get_set_refs(self):
+        """
+        Get a list of references that could get altered by set expression in this rule. 
+        This list is empty if the relation does not have any set expressions.
+        @return: a list of references.
+        """
+        
+        return [exp.left.get_ref() for exp in self.get_set_expressions()]
+
+    def get_expressions(self):
+        if not self.interpreter.expression_list:
+            self.interpreter.create_ast('%s %s %s' % (self.left, self.KEY, self.right))
+        return self.interpreter.expression_list
+    
+    def get_set_expressions(self):
+        setelems = []
+        if not self.interpreter.expression_list:
+            self.interpreter.create_ast('%s %s %s' % (self.left, self.KEY, self.right))
+        for elem in self.interpreter.expression_list:
+            if isinstance(elem, SetExpression):
+                setelems.append(elem)
+        return setelems
+        
     def _eval_rside_value(self, value): abstract()
     def _compare_value(self, value): abstract()
 
@@ -279,7 +342,7 @@
         return self.interpreter.errors
 
     def expand_rule_elements(self, rule):
-        """ Expans rule elements base on the reference.
+        """ Expands rule elements base on the reference.
         Context is used for fetching the child elements for parent references
         which uses asterisk identifier for selecting all child features: 
         'parent_feature.*' -> 'child_fea_1 and child_fea_2'.
@@ -298,7 +361,11 @@
                 else:
                     expanded_rule = expanded_element.rstrip()
             elif token.lower() in OPERATORS:
-                expanded_rule += ' %s ' % token
+                operator_class = OPERATORS[token]
+                if operator_class.PARAM_COUNT == 2:
+                    expanded_rule += ' %s ' % token
+                else:
+                    expanded_rule += '%s ' % token
             else:
                 if expanded_rule:
                     expanded_rule += '%s'% token
@@ -337,13 +404,32 @@
     def eval(self, ast, expression, value):
         pass
 
-    def get_keys(self, refs):
-        return ASTInterpreter.extract_refs(refs)
+    def set(self, expression, value):
+        """
+        set a element described with expression to value
+        @param expression: the expression refering to a element 
+        @param value:  the value to set 
+        @raise exception: when the setting value to expression fails.  
+        """
+        pass
 
     def get_children_for_reference(self, reference):
         # implement ConE specific children expansion
         pass
 
+    def convert_value(self, value):
+        if value in ('True', 'true', '1'):
+            return True
+        elif value in ('False', 'false', '0'):
+            return False
+        elif value in ('None',):
+            return None
+        else:
+            try:
+                return int(value)
+            except:
+                return value
+
     def handle_terminal(self, expression):
         try:
             return int(expression)
@@ -374,10 +460,18 @@
     def get_title(self):
         return self.KEY
 
-    def eval(self, context): pass
+    def is_terminal(self):
+        return False
+
+    def eval(self, context, **kwargs): pass
+
+    def get_refs(self): return []
 
 class OneParamExpression(Expression):
     PARAM_COUNT = 1
+    # OP that return itself 
+    OP = lambda _, x : x
+
     def __init__(self, ast, expression):
         super(OneParamExpression, self).__init__(ast)
         self.expression = expression
@@ -385,8 +479,11 @@
     def __unicode__(self):
         return u'%s %s' % (self.KEY, self.expression)
 
-    def eval(self, context):
-        self.value = self.OP(self.expression.eval(context))
+    def __str__(self):
+        return '%s %s' % (self.KEY, self.expression)
+
+    def eval(self, context, **kwargs):
+        self.value = self.OP(self.expression.eval(context)) 
         context.eval(self.ast, self, self.value)
         return self.value
 
@@ -403,38 +500,128 @@
     def __unicode__(self):
         return u'%s %s %s' % (self.left, self.KEY, self.right)
 
-    def eval(self, context):
+    def __str__(self):
+        return '%s %s %s' % (self.left, self.KEY, self.right)
+
+    def eval(self, context, **kwargs):
         self.value = self.OP(self.left.eval(context), self.right.eval(context))
         context.eval(self.ast, self, self.value)
         return self.value
 
 class TwoOperatorBooleanExpression(TwoOperatorExpression):
-    def eval(self, context):
-        self.value = self.OP(bool(self.left.eval(context)), bool(self.right.eval(context)))
+    def eval(self, context, **kwargs):
+        self.value = self.OP(bool(self.left.eval(context, **kwargs)), bool(self.right.eval(context, **kwargs)))
         context.eval(self.ast, self, self.value)
         return self.value         
 
-class TerminalExpression(Expression):
-    KEY = 'terminal'
+class ReferenceTerminal(Expression):
+    PARAM_COUNT = 0
+    KEY = 'reference'
 
     def __init__(self, ast, expression):
-        super(TerminalExpression, self).__init__(ast)
+        super(ReferenceTerminal, self).__init__(ast)
+        if not ASTInterpreter.is_ref(unicode(expression)):
+            expression = ASTInterpreter.create_ref(expression)
         self.expression = expression
 
-    def eval(self, context):
+    def is_terminal(self):
+        return True
+
+    def eval(self, context, **kwargs):
         """ Use context to eval the value
-        Expression on TerminalExpression is feature reference or value
+        Expression on ReferenceTerminal is feature reference or value
         context should handle the reference conversion to correct value
         """
-        self.value = context.handle_terminal(self.expression)
+        self.value = context.handle_terminal(ASTInterpreter.clean_ref(self.expression))
+        return self.value
+    
+    def get_ref(self):
+        """
+        @return: The setting reference, e.g. 'MyFeature.MySetting'
+        """
+        return ASTInterpreter.clean_ref(self.expression)
+
+    def get_refs(self):
+        """
+        """
+        return [u'%s' % self.get_ref()]
+
+    def __unicode__(self):
+        return self.expression
+    
+    def __str__(self):
+        return "(%s => %s)" % (self.expression, self.value)
+
+    def __repr__(self):
+        return self.expression
+
+class ValueTerminal(Expression):
+    PARAM_COUNT = 0
+    KEY = 'value_terminal'
+
+    def __init__(self, ast, expression):
+        super(ValueTerminal, self).__init__(ast)
+        self.expression = expression
+
+    def is_terminal(self):
+        return True
+
+    def eval(self, context, **kwargs):
+        self.value = context.convert_value(self.expression)
         return self.value
 
     def __unicode__(self):
         return self.expression
     
-    def __repr__(self):
+    def __str__(self):
         return self.expression
 
+class TypeCoercionError(exceptions.ConeException):
+    pass
+
+class AutoValueTerminal(Expression):
+    PARAM_COUNT = 0
+    KEY = 'autovalue_terminal'
+
+    def __init__(self, ast, expression):
+        super(AutoValueTerminal, self).__init__(ast)
+        self.expression = expression
+
+    def is_terminal(self):
+        return True
+
+    def eval(self, context, **kwargs):
+        type = kwargs.get('type', None)
+        
+        if self.expression in ("None", None):
+            self.value = None
+        elif type in (types.IntType, types.FloatType):
+            try:
+                self.value = type(self.expression)
+            except ValueError:
+                raise TypeCoercionError("Cannot coerce %r to %s" % (self.expression, type))
+        elif type == types.BooleanType:
+            if self.expression in ('True', 'true', True):
+                self.value = True
+            elif self.expression in ('False', 'false', False):
+                self.value = False
+            else:
+                raise TypeCoercionError("Cannot coerce %r to %s" % (self.expression, type))
+        elif type in types.StringTypes:
+            self.value = unicode(self.expression)
+        elif type == types.ListType:
+            self.value = list(self.expression)
+        else:
+            raise TypeCoercionError("Cannot coerce %r to %s" % (self.expression, type))
+        return self.value
+    
+    def __unicode__(self):
+        return self.expression
+    
+    def __str__(self):
+        return self.expression
+
+
 class NegExpression(OneParamExpression):
     PRECEDENCE = PRECEDENCES['PREFIX_OPERATORS']
     KEY= '-'
@@ -470,6 +657,11 @@
     KEY = '=='
     OP = ops.eq
 
+    def eval(self, context, **kwargs):
+        self.value = self.OP(self.left.eval(context), self.right.eval(context))
+        context.eval(self.ast, self, self.value)
+        return self.value
+
 class NotEqualExpression(TwoOperatorExpression):
     PRECEDENCE = PRECEDENCES['COMPARISON_OPERATORS']
     KEY = '!='
@@ -496,6 +688,59 @@
     OP = ops.ge
 
 
+def handle_multiply(self, left, right):
+    return left * right
+
+class MultiplyExpression(TwoOperatorExpression):
+    expression = "multiply_operation"
+    PRECEDENCE = PRECEDENCES['MULDIV_OPERATORS']
+    KEY= '*'
+    OP = handle_multiply
+
+def handle_divide(self, left, right):
+    return left / right
+
+class DivideExpression(TwoOperatorExpression):
+    expression = "divide_operation"
+    PRECEDENCE = PRECEDENCES['MULDIV_OPERATORS']
+    KEY= '/'
+    OP = handle_divide
+
+def handle_plus(self, left, right):
+    return left + right
+
+class PlusExpression(TwoOperatorExpression):
+    expression = "plus_operation"
+    PRECEDENCE = PRECEDENCES['ADDSUB_OPERATORS']
+    KEY= '+'
+    OP = handle_plus
+
+def handle_minus(self, left, right):
+    return left - right
+
+class MinusExpression(TwoOperatorExpression):
+    expression = "minus_operation"
+    PRECEDENCE = PRECEDENCES['ADDSUB_OPERATORS']
+    KEY= '-'
+    OP = handle_minus
+
+def handle_set(self, left, right):
+    left.set_value(right)
+
+class SetExpression(TwoOperatorExpression):
+    PRECEDENCE = PRECEDENCES['SET_OPERATORS']
+    KEY= '='
+    OP = handle_set
+
+    def eval(self, context, **kwargs):
+        if not isinstance(self.left, ReferenceTerminal):
+            raise RuntimeError("Can only set the value of a setting, '%s' is not a setting reference. Did you forget to use ${}?" % self.left.expression)
+        
+        value = self.right.eval(context, **kwargs)
+        context.set(self.left.get_ref(), value, **kwargs)
+        return True
+
+
 def handle_require(expression, left, right):
     if left and right:
         return True
@@ -508,13 +753,12 @@
     KEY = 'requires'
     OP = handle_require
 
-    def eval(self, context):
+    def eval(self, context, **kwargs):
         super(RequireExpression, self).eval(context)
         if not self.value:
             left_keys = []
             for ref in self.ast.extract_refs(unicode(self.left)):
-                for key in context.get_keys(ref):
-                    left_keys.append(key)
+                left_keys.append(ref)
 
             for key in left_keys:
                 self.ast.add_error(key, { 'error_string' : 'REQUIRES right side value is "False"',
@@ -535,13 +779,12 @@
     KEY = 'excludes'
     OP = handle_exclude
 
-    def eval(self, context):
+    def eval(self, context, **kwargs):
         super(ExcludeExpression, self).eval(context)
         if not self.value:
             left_keys = []
             for ref in self.ast.extract_refs(unicode(self.left)):
-                for key in context.get_keys(ref):
-                    left_keys.append(key)
+                left_keys.append(ref)
                     
             for key in left_keys:
                 self.ast.add_error(key, { 'error_string' : 'EXCLUDE right side value is "True"',
@@ -568,11 +811,38 @@
     A simple condition object that can refer to a model object and evaluate if the value matches  
     """
     def __init__(self, left, right):
-        lterm = TerminalExpression(None, left)
-        rterm = TerminalExpression(None, right)
+        if isinstance(left, basestring) and ASTInterpreter.is_ref(left):
+            lterm = ReferenceTerminal(None, left)
+        else:
+            lterm = ValueTerminal(None, left)
+        if isinstance(right, basestring) and ASTInterpreter.is_ref(right):
+            rterm = ReferenceTerminal(None, right)
+        else:
+            rterm = AutoValueTerminal(None, right)
         EqualExpression.__init__(self, None, lterm, rterm)
 
-
+    def eval(self, context, **kwargs):
+        left_value = self.left.eval(context)
+        try:
+            right_value = self.right.eval(context, type=type(left_value))
+        except TypeCoercionError:
+            # If type coercion fails, the result is always False
+            self.value = False
+            return self.value
+        
+        # Type coercion successful, perform value comparison
+        self.value = self.OP(left_value, right_value)
+        context.eval(self.ast, self, self.value)
+        return self.value
+    
+    def get_refs(self):
+        result = []
+        if isinstance(self.left, ReferenceTerminal):
+            result.append(self.left.get_ref())
+        if isinstance(self.right, ReferenceTerminal):
+            result.append(self.right.get_ref())
+        return result
+    
 # in format KEY : OPERATOR CLASS
 OPERATORS = {
     'and' : AndExpression,
@@ -590,7 +860,11 @@
     '>=' : GreaterThanEqualExpression,
     'requires' : RequireExpression,
     'excludes' : ExcludeExpression,
-    '-' : NegExpression
+    '-' : MinusExpression,
+    '+' : PlusExpression,
+    '*' : MultiplyExpression,
+    '/' : DivideExpression,
+    '=' : SetExpression
     }
 
 def add_operator(key, operator_class=None, baseclass=RequireExpression):
@@ -613,6 +887,29 @@
 
 class ParseException(Exception): pass
 
+def is_str_literal(value):
+    """
+    return true if the value is a string literal. A string that begins and ends with single or douple quotes.
+    @param value: the value to investigate
+    @return: Boolean
+    """
+    if  isinstance(value, (str, unicode)):
+        if re.match("[\"\'].*[\"\']", value):
+            return True
+    return False
+
+def get_str_literal(value):
+    """
+    return the string literal value
+    @param value: the value to convert
+    @return: string or unicode based on the input value
+    """
+    if  isinstance(value, (str, unicode)):
+        m =  re.match("[\"\'](.*)[\"\']", value)
+        if m:
+            return m.group(1)
+    return None
+
 class ASTInterpreter(object):
     def __init__(self, infix_expression=None, context=None):
         """ Takes infix expression as string """
@@ -629,6 +926,7 @@
         self.errors = {}
         self.postfix_array = []
         self.parse_tree = []
+        self.expression_list = []
         self.expression = infix_expression
 
     def __unicode__(self):
@@ -652,11 +950,11 @@
 
     def _infix_to_postfix(self):
         """
-        Shunting yard algorithm used to convert infix presentation to postfix.
+        Shunting yard algorithm used to convert infix presentation to postxfix.
         """
         if not self.expression:
             raise ParseException('Expression is None')
-        tokens = get_tokens(self.expression) # [token for token in self.expression.split()]
+        tokens = get_tokens(self.expression)
         stack = []
         # logger.debug('TOKENS: %s' % tokens)
         for token in tokens:
@@ -721,6 +1019,7 @@
     def _create_parse_tree(self):
         self.parse_tree = []
         for token in self.postfix_array:
+            is_ref = ASTInterpreter.is_ref(token)
             if token in OPERATORS:
                 # logger.debug('OP: %s' % (token))
                 expression_class = OPERATORS[token]
@@ -735,9 +1034,15 @@
 
                 # logger.debug('The operation: %s' % expression)
                 self.parse_tree.append(expression)
+                self.expression_list.append(expression)
+            elif not is_ref:
+                expression = ValueTerminal(self, token)
+                self.parse_tree.append(expression)
+                self.expression_list.append(expression)
             else:
-                expression = TerminalExpression(self, token)
+                expression = ReferenceTerminal(self, token)
                 self.parse_tree.append(expression)
+                self.expression_list.append(expression)
 
         #logger.debug('THE STACK: %s' % self.parse_tree)
         #for s in self.parse_tree:
@@ -745,22 +1050,41 @@
 
         return self.parse_tree
 
-    def eval(self):
+    def eval(self, context=None, **kwargs):
         """ Evals the AST
         If empty expression is given, None is returned
         """
         for expression in self.parse_tree:
-            self.value = expression.eval(self.context)
+            self.value = expression.eval(context or self.context, **kwargs)
         return self.value
 
     @staticmethod
     def extract_refs(expression):
         tokens = get_tokens(expression)
-        refs = []
-        for token in tokens:
-            if not token.lower() in OPERATORS and token != LEFT_PARENTHESIS and token != RIGHT_PARENTHESIS:
-                refs.append(token.strip('%s%s' % (LEFT_PARENTHESIS, RIGHT_PARENTHESIS)))
-        return refs
+        return [ASTInterpreter.clean_ref(t) for t in tokens if t.lower() not in OPERATORS and\
+                    t not in (LEFT_PARENTHESIS, RIGHT_PARENTHESIS) and\
+                    ASTInterpreter.is_ref(t)]
+
+    @staticmethod
+    def extract_non_operators(expression):
+        tokens = get_tokens(expression)
+        return [ASTInterpreter.clean_ref(t) for t in tokens if t.lower() not in OPERATORS and\
+                    t not in (LEFT_PARENTHESIS, RIGHT_PARENTHESIS)]
+
+    @staticmethod
+    def clean_ref(ref):
+        return ref.replace('$', '').replace('{', '').replace('}', '')
+
+    @staticmethod
+    def create_ref(ref):
+        return '${%s}' % ref
+
+    @staticmethod
+    def is_ref(val):
+        mo = REF_REGEX.match(val)
+        if mo and len(mo.groups()) > 0 and mo.group() == val:
+            return True
+        return False
 
 ##################################################################
 # Create and configure the main level logger