Browse Source

interrogate: Support reverse binary operators as extension methods

Necessary for implementing #1048
rdb 4 years ago
parent
commit
98054d1bbd
1 changed files with 107 additions and 42 deletions
  1. 107 42
      dtool/src/interrogate/interfaceMakerPythonNative.cxx

+ 107 - 42
dtool/src/interrogate/interfaceMakerPythonNative.cxx

@@ -307,7 +307,9 @@ get_slotted_function_def(Object *obj, Function *func, FunctionRemap *remap,
   string method_name = func->_ifunc.get_name();
   bool is_unary_op = func->_ifunc.is_unary_op();
 
-  if (method_name == "operator +") {
+  if (method_name == "operator +" ||
+      method_name == "__add__" ||
+      method_name == "__radd__") {
     def._answer_location = "nb_add";
     def._wrapper_type = WT_binary_operator;
     return true;
@@ -319,13 +321,17 @@ get_slotted_function_def(Object *obj, Function *func, FunctionRemap *remap,
     return true;
   }
 
-  if (method_name == "operator -") {
+  if (method_name == "operator -" ||
+      method_name == "__sub__" ||
+      method_name == "__rsub__") {
     def._answer_location = "nb_subtract";
     def._wrapper_type = WT_binary_operator;
     return true;
   }
 
-  if (method_name == "operator *") {
+  if (method_name == "operator *" ||
+      method_name == "__mul__" ||
+      method_name == "__rmul__") {
     def._answer_location = "nb_multiply";
     def._wrapper_type = WT_binary_operator;
     return true;
@@ -337,37 +343,47 @@ get_slotted_function_def(Object *obj, Function *func, FunctionRemap *remap,
     return true;
   }
 
-  if (method_name == "__truediv__") {
+  if (method_name == "__truediv__" ||
+      method_name == "__rtruediv__") {
     def._answer_location = "nb_true_divide";
     def._wrapper_type = WT_binary_operator;
     return true;
   }
 
-  if (method_name == "__floordiv__") {
+  if (method_name == "__floordiv__" ||
+      method_name == "__rfloordiv__") {
     def._answer_location = "nb_floor_divide";
     def._wrapper_type = WT_binary_operator;
     return true;
   }
 
-  if (method_name == "operator %") {
+  if (method_name == "operator %" ||
+      method_name == "__mod__" ||
+      method_name == "__rmod__") {
     def._answer_location = "nb_remainder";
     def._wrapper_type = WT_binary_operator;
     return true;
   }
 
-  if (method_name == "operator <<") {
+  if (method_name == "operator <<" ||
+      method_name == "__lshift__" ||
+      method_name == "__rlshift__") {
     def._answer_location = "nb_lshift";
     def._wrapper_type = WT_binary_operator;
     return true;
   }
 
-  if (method_name == "operator >>") {
+  if (method_name == "operator >>" ||
+      method_name == "__rshift__" ||
+      method_name == "__rrshift__") {
     def._answer_location = "nb_rshift";
     def._wrapper_type = WT_binary_operator;
     return true;
   }
 
-  if (method_name == "operator ^") {
+  if (method_name == "operator ^" ||
+      method_name == "__xor__" ||
+      method_name == "__rxor__") {
     def._answer_location = "nb_xor";
     def._wrapper_type = WT_binary_operator;
     return true;
@@ -379,13 +395,17 @@ get_slotted_function_def(Object *obj, Function *func, FunctionRemap *remap,
     return true;
   }
 
-  if (method_name == "operator &") {
+  if (method_name == "operator &" ||
+      method_name == "__and__" ||
+      method_name == "__rand__") {
     def._answer_location = "nb_and";
     def._wrapper_type = WT_binary_operator;
     return true;
   }
 
-  if (method_name == "operator |") {
+  if (method_name == "operator |" ||
+      method_name == "__or__" ||
+      method_name == "__ror__") {
     def._answer_location = "nb_or";
     def._wrapper_type = WT_binary_operator;
     return true;
@@ -1896,16 +1916,8 @@ write_module_class(ostream &out, Object *obj) {
         break;
 
       case WT_one_param:
-      case WT_binary_operator:
-      case WT_inplace_binary_operator:
         // PyObject *func(PyObject *self, PyObject *one)
         {
-          int return_flags = RF_err_null;
-          if (rfi->second._wrapper_type == WT_inplace_binary_operator) {
-            return_flags |= RF_self;
-          } else {
-            return_flags |= RF_pyobject;
-          }
           bool all_nonconst = true;
           for (FunctionRemap *remap : def._remaps) {
             if (remap->_const_method) {
@@ -1918,20 +1930,7 @@ write_module_class(ostream &out, Object *obj) {
           out << "//////////////////\n";
           out << "static PyObject *" << def._wrapper_name << "(PyObject *self, PyObject *arg) {\n";
           out << "  " << cClassName << " *local_this = nullptr;\n";
-          if (rfi->second._wrapper_type != WT_one_param) {
-            // WT_binary_operator means we must return NotImplemented, instead
-            // of raising an exception, if the this pointer doesn't match.
-            // This is for things like __sub__, which Python likes to call on
-            // the wrong-type objects.
-            out << "  DTOOL_Call_ExtractThisPointerForType(self, &Dtool_" << ClassName << ", (void **)&local_this);\n";
-            if (all_nonconst) {
-              out << "  if (local_this == nullptr || DtoolInstance_IS_CONST(self)) {\n";
-            } else {
-              out << "  if (local_this == nullptr) {\n";
-            }
-            out << "    Py_INCREF(Py_NotImplemented);\n";
-            out << "    return Py_NotImplemented;\n";
-          } else if (all_nonconst) {
+          if (all_nonconst) {
             out << "  if (!Dtool_Call_ExtractThisPointer_NonConst(self, Dtool_"
                 << ClassName << ", (void **)&local_this, \"" << ClassName
                 << "." << methodNameFromCppName(fname, "", false) << "\")) {\n";
@@ -1944,19 +1943,85 @@ write_module_class(ostream &out, Object *obj) {
 
           string expected_params;
           write_function_forset(out, def._remaps, 1, 1, expected_params, 2, true, true,
-                                AT_single_arg, return_flags, false, !all_nonconst);
+                                AT_single_arg, RF_err_null | RF_pyobject, false, !all_nonconst);
+
+          out << "  if (!_PyErr_OCCURRED()) {\n";
+          out << "    return Dtool_Raise_BadArgumentsError(\n";
+          output_quoted(out, 6, expected_params);
+          out << ");\n";
+          out << "  }\n";
+          out << "  return nullptr;\n";
+          out << "}\n\n";
+        }
+        break;
 
-          if (rfi->second._wrapper_type != WT_one_param) {
-            out << "  Py_INCREF(Py_NotImplemented);\n";
-            out << "  return Py_NotImplemented;\n";
+      case WT_binary_operator:
+      case WT_inplace_binary_operator:
+        // PyObject *func(PyObject *self, PyObject *one)
+        {
+          int return_flags = RF_err_null;
+          if (rfi->second._wrapper_type == WT_inplace_binary_operator) {
+            return_flags |= RF_self;
           } else {
-            out << "  if (!_PyErr_OCCURRED()) {\n";
-            out << "    return Dtool_Raise_BadArgumentsError(\n";
-            output_quoted(out, 6, expected_params);
-            out << ");\n";
+            return_flags |= RF_pyobject;
+          }
+          bool forward_all_nonconst = true;
+          bool reverse_all_nonconst = true;
+          set<FunctionRemap *> forward_remaps;
+          set<FunctionRemap *> reverse_remaps;
+          for (FunctionRemap *remap : def._remaps) {
+            std::string fname = remap->_cppfunc->get_simple_name();
+            if (fname.compare(0, 3, "__r") == 0 && fname != "__rshift__") {
+              reverse_remaps.insert(remap);
+              if (remap->_const_method) {
+                reverse_all_nonconst = false;
+              }
+            } else {
+              forward_remaps.insert(remap);
+              if (remap->_const_method) {
+                forward_all_nonconst = false;
+              }
+            }
+          }
+          out << "//////////////////\n";
+          out << "// A wrapper function to satisfy Python's internal calling conventions.\n";
+          out << "// " << ClassName << " slot " << rfi->second._answer_location << " -> " << fname << "\n";
+          out << "//////////////////\n";
+          out << "static PyObject *" << def._wrapper_name << "(PyObject *self, PyObject *arg) {\n";
+          out << "  " << cClassName << " *local_this = nullptr;\n";
+          // WT_binary_operator means we must return NotImplemented, instead
+          // of raising an exception, if the this pointer doesn't match.
+          // This is for things like __sub__, which Python likes to call on
+          // the wrong-type objects.
+          if (!forward_remaps.empty()) {
+            out << "  DTOOL_Call_ExtractThisPointerForType(self, &Dtool_" << ClassName << ", (void **)&local_this);\n";
+            if (forward_all_nonconst) {
+              out << "  if (local_this != nullptr && !DtoolInstance_IS_CONST(self)) {\n";
+            } else {
+              out << "  if (local_this != nullptr) {\n";
+            }
+            string expected_params;
+            write_function_forset(out, forward_remaps, 1, 1, expected_params, 4, true, true,
+                                  AT_single_arg, return_flags, false, !forward_all_nonconst);
             out << "  }\n";
-            out << "  return nullptr;\n";
           }
+
+          if (!reverse_remaps.empty()) {
+            out << "  std::swap(self, arg);\n";
+            out << "  DTOOL_Call_ExtractThisPointerForType(self, &Dtool_" << ClassName << ", (void **)&local_this);\n";
+            if (reverse_all_nonconst) {
+              out << "  if (local_this != nullptr && !DtoolInstance_IS_CONST(self)) {\n";
+            } else {
+              out << "  if (local_this != nullptr) {\n";
+            }
+            string expected_params;
+            write_function_forset(out, reverse_remaps, 1, 1, expected_params, 4, true, true,
+                                  AT_single_arg, return_flags, false, !reverse_all_nonconst);
+            out << "  }\n";
+          }
+
+          out << "  Py_INCREF(Py_NotImplemented);\n";
+          out << "  return Py_NotImplemented;\n";
           out << "}\n\n";
         }
         break;