瀏覽代碼

feat: Support subtypes for `ExcalidrawTextElement`

Daniel J. Geiger 1 年之前
父節點
當前提交
7958b7144a

+ 6 - 6
src/actions/actionBoundText.tsx

@@ -10,7 +10,7 @@ import {
   computeBoundTextPosition,
   computeContainerDimensionForBoundText,
   getBoundTextElement,
-  measureText,
+  measureTextElement,
   redrawTextBoundingBox,
 } from "../element/textElement";
 import {
@@ -31,7 +31,6 @@ import {
 } from "../element/types";
 import { AppState } from "../types";
 import { Mutable } from "../utility-types";
-import { getFontString } from "../utils";
 import { register } from "./register";
 
 export const actionUnbindText = register({
@@ -48,10 +47,11 @@ export const actionUnbindText = register({
     selectedElements.forEach((element) => {
       const boundTextElement = getBoundTextElement(element);
       if (boundTextElement) {
-        const { width, height, baseline } = measureText(
-          boundTextElement.originalText,
-          getFontString(boundTextElement),
-          boundTextElement.lineHeight,
+        const { width, height, baseline } = measureTextElement(
+          boundTextElement,
+          {
+            text: boundTextElement.originalText,
+          },
         );
         const originalContainerHeight = getOriginalContainerHeightFromCache(
           element.id,

+ 8 - 8
src/data/restore.ts

@@ -34,13 +34,13 @@ import {
 import { getDefaultAppState } from "../appState";
 import { LinearElementEditor } from "../element/linearElementEditor";
 import { bumpVersion } from "../element/mutateElement";
-import { getFontString, getUpdatedTimestamp, updateActiveTool } from "../utils";
+import { getUpdatedTimestamp, updateActiveTool } from "../utils";
 import { arrayToMap } from "../utils";
 import { MarkOptional, Mutable } from "../utility-types";
 import {
   detectLineHeight,
   getDefaultLineHeight,
-  measureBaseline,
+  measureTextElement,
 } from "../element/textElement";
 import { normalizeLink } from "./url";
 
@@ -93,7 +93,8 @@ const repairBinding = (binding: PointBinding | null) => {
 };
 
 const restoreElementWithProperties = <
-  T extends Required<Omit<ExcalidrawElement, "customData">> & {
+  T extends Required<Omit<ExcalidrawElement, "subtype" | "customData">> & {
+    subtype?: ExcalidrawElement["subtype"];
     customData?: ExcalidrawElement["customData"];
     /** @deprecated */
     boundElementIds?: readonly ExcalidrawElement["id"][];
@@ -159,6 +160,9 @@ const restoreElementWithProperties = <
     locked: element.locked ?? false,
   };
 
+  if ("subtype" in element) {
+    base.subtype = element.subtype;
+  }
   if ("customData" in element) {
     base.customData = element.customData;
   }
@@ -204,11 +208,7 @@ const restoreElement = (
           : // no element height likely means programmatic use, so default
             // to a fixed line height
             getDefaultLineHeight(element.fontFamily));
-      const baseline = measureBaseline(
-        element.text,
-        getFontString(element),
-        lineHeight,
-      );
+      const baseline = measureTextElement(element, { text }).baseline;
       element = restoreElementWithProperties(element, {
         fontSize,
         fontFamily,

+ 24 - 3
src/element/mutateElement.ts

@@ -6,12 +6,21 @@ import { Point } from "../types";
 import { getUpdatedTimestamp } from "../utils";
 import { Mutable } from "../utility-types";
 import { ShapeCache } from "../scene/ShapeCache";
+import { getSubtypeMethods } from "./subtypes";
 
 type ElementUpdate<TElement extends ExcalidrawElement> = Omit<
   Partial<TElement>,
   "id" | "version" | "versionNonce"
 >;
 
+const cleanUpdates = <TElement extends Mutable<ExcalidrawElement>>(
+  element: TElement,
+  updates: ElementUpdate<TElement>,
+): ElementUpdate<TElement> => {
+  const map = getSubtypeMethods(element.subtype);
+  return map?.clean ? (map.clean(updates) as typeof updates) : updates;
+};
+
 // This function tracks updates of text elements for the purposes for collaboration.
 // The version is used to compare updates when more than one user is working in
 // the same drawing. Note: this will trigger the component to update. Make sure you
@@ -22,6 +31,8 @@ export const mutateElement = <TElement extends Mutable<ExcalidrawElement>>(
   informMutation = true,
 ): TElement => {
   let didChange = false;
+  let increment = false;
+  const oldUpdates = cleanUpdates(element, updates);
 
   // casting to any because can't use `in` operator
   // (see https://github.com/microsoft/TypeScript/issues/21732)
@@ -70,6 +81,7 @@ export const mutateElement = <TElement extends Mutable<ExcalidrawElement>>(
             }
           }
           if (!didChangePoints) {
+            key in oldUpdates && (increment = true);
             continue;
           }
         }
@@ -77,6 +89,7 @@ export const mutateElement = <TElement extends Mutable<ExcalidrawElement>>(
 
       (element as any)[key] = value;
       didChange = true;
+      key in oldUpdates && (increment = true);
     }
   }
   if (!didChange) {
@@ -92,9 +105,11 @@ export const mutateElement = <TElement extends Mutable<ExcalidrawElement>>(
     ShapeCache.delete(element);
   }
 
-  element.version++;
-  element.versionNonce = randomInteger();
-  element.updated = getUpdatedTimestamp();
+  if (increment) {
+    element.version++;
+    element.versionNonce = randomInteger();
+    element.updated = getUpdatedTimestamp();
+  }
 
   if (informMutation) {
     Scene.getScene(element)?.informMutation();
@@ -108,6 +123,8 @@ export const newElementWith = <TElement extends ExcalidrawElement>(
   updates: ElementUpdate<TElement>,
 ): TElement => {
   let didChange = false;
+  let increment = false;
+  const oldUpdates = cleanUpdates(element, updates);
   for (const key in updates) {
     const value = (updates as any)[key];
     if (typeof value !== "undefined") {
@@ -119,6 +136,7 @@ export const newElementWith = <TElement extends ExcalidrawElement>(
         continue;
       }
       didChange = true;
+      key in oldUpdates && (increment = true);
     }
   }
 
@@ -126,6 +144,9 @@ export const newElementWith = <TElement extends ExcalidrawElement>(
     return element;
   }
 
+  if (!increment) {
+    return { ...element, ...updates };
+  }
   return {
     ...element,
     ...updates,

+ 39 - 24
src/element/newElement.ts

@@ -15,12 +15,7 @@ import {
   ExcalidrawFrameElement,
   ExcalidrawEmbeddableElement,
 } from "../element/types";
-import {
-  arrayToMap,
-  getFontString,
-  getUpdatedTimestamp,
-  isTestEnv,
-} from "../utils";
+import { arrayToMap, getUpdatedTimestamp, isTestEnv } from "../utils";
 import { randomInteger, randomId } from "../random";
 import { bumpVersion, newElementWith } from "./mutateElement";
 import { getNewGroupIdsForDuplication } from "../groups";
@@ -30,9 +25,9 @@ import { adjustXYWithRotation } from "../math";
 import { getResizedElementAbsoluteCoords } from "./bounds";
 import {
   getContainerElement,
-  measureText,
+  measureTextElement,
   normalizeText,
-  wrapText,
+  wrapTextElement,
   getBoundTextMaxWidth,
   getDefaultLineHeight,
 } from "./textElement";
@@ -45,6 +40,21 @@ import {
   VERTICAL_ALIGN,
 } from "../constants";
 import { MarkOptional, Merge, Mutable } from "../utility-types";
+import { getSubtypeMethods } from "./subtypes";
+
+export const maybeGetSubtypeProps = (obj: {
+  subtype?: ExcalidrawElement["subtype"];
+  customData?: ExcalidrawElement["customData"];
+}) => {
+  const data: typeof obj = {};
+  if ("subtype" in obj && obj.subtype !== undefined) {
+    data.subtype = obj.subtype;
+  }
+  if ("customData" in obj && obj.customData !== undefined) {
+    data.customData = obj.customData;
+  }
+  return data as typeof obj;
+};
 
 export type ElementConstructorOpts = MarkOptional<
   Omit<ExcalidrawGenericElement, "id" | "type" | "isDeleted" | "updated">,
@@ -58,6 +68,8 @@ export type ElementConstructorOpts = MarkOptional<
   | "version"
   | "versionNonce"
   | "link"
+  | "subtype"
+  | "customData"
   | "strokeStyle"
   | "fillStyle"
   | "strokeColor"
@@ -93,8 +105,10 @@ const _newElementBase = <T extends ExcalidrawElement>(
     ...rest
   }: ElementConstructorOpts & Omit<Partial<ExcalidrawGenericElement>, "type">,
 ) => {
+  const { subtype, customData } = rest;
   // assign type to guard against excess properties
   const element: Merge<ExcalidrawGenericElement, { type: T["type"] }> = {
+    ...maybeGetSubtypeProps({ subtype, customData }),
     id: rest.id || randomId(),
     type,
     x,
@@ -128,8 +142,11 @@ export const newElement = (
   opts: {
     type: ExcalidrawGenericElement["type"];
   } & ElementConstructorOpts,
-): NonDeleted<ExcalidrawGenericElement> =>
-  _newElementBase<ExcalidrawGenericElement>(opts.type, opts);
+): NonDeleted<ExcalidrawGenericElement> => {
+  const map = getSubtypeMethods(opts?.subtype);
+  map?.clean && map.clean(opts);
+  return _newElementBase<ExcalidrawGenericElement>(opts.type, opts);
+};
 
 export const newEmbeddableElement = (
   opts: {
@@ -196,10 +213,12 @@ export const newTextElement = (
   const fontSize = opts.fontSize || DEFAULT_FONT_SIZE;
   const lineHeight = opts.lineHeight || getDefaultLineHeight(fontFamily);
   const text = normalizeText(opts.text);
-  const metrics = measureText(
-    text,
-    getFontString({ fontFamily, fontSize }),
-    lineHeight,
+  const metrics = measureTextElement(
+    { ...opts, fontSize, fontFamily, lineHeight },
+    {
+      text,
+      customData: opts.customData,
+    },
   );
   const textAlign = opts.textAlign || DEFAULT_TEXT_ALIGN;
   const verticalAlign = opts.verticalAlign || DEFAULT_VERTICAL_ALIGN;
@@ -244,7 +263,9 @@ const getAdjustedDimensions = (
     width: nextWidth,
     height: nextHeight,
     baseline: nextBaseline,
-  } = measureText(nextText, getFontString(element), element.lineHeight);
+  } = measureTextElement(element, {
+    text: nextText,
+  });
   const { textAlign, verticalAlign } = element;
   let x: number;
   let y: number;
@@ -253,11 +274,7 @@ const getAdjustedDimensions = (
     verticalAlign === VERTICAL_ALIGN.MIDDLE &&
     !element.containerId
   ) {
-    const prevMetrics = measureText(
-      element.text,
-      getFontString(element),
-      element.lineHeight,
-    );
+    const prevMetrics = measureTextElement(element);
     const offsets = getTextElementPositionOffsets(element, {
       width: nextWidth - prevMetrics.width,
       height: nextHeight - prevMetrics.height,
@@ -313,11 +330,9 @@ export const refreshTextDimensions = (
   }
   const container = getContainerElement(textElement);
   if (container) {
-    text = wrapText(
+    text = wrapTextElement(textElement, getBoundTextMaxWidth(container), {
       text,
-      getFontString(textElement),
-      getBoundTextMaxWidth(container),
-    );
+    });
   }
   const dimensions = getAdjustedDimensions(textElement, text);
   return { text, ...dimensions };

+ 2 - 6
src/element/resizeElements.ts

@@ -51,7 +51,7 @@ import {
   handleBindTextResize,
   getBoundTextMaxWidth,
   getApproxMinLineHeight,
-  measureText,
+  measureTextElement,
   getBoundTextMaxHeight,
 } from "./textElement";
 import { LinearElementEditor } from "./linearElementEditor";
@@ -224,11 +224,7 @@ const measureFontSizeFromWidth = (
   if (nextFontSize < MIN_FONT_SIZE) {
     return null;
   }
-  const metrics = measureText(
-    element.text,
-    getFontString({ fontSize: nextFontSize, fontFamily: element.fontFamily }),
-    element.lineHeight,
-  );
+  const metrics = measureTextElement(element, { fontSize: nextFontSize });
   return {
     size: nextFontSize,
     baseline: metrics.baseline + (nextHeight - metrics.height),

+ 222 - 0
src/element/subtypes/index.ts

@@ -0,0 +1,222 @@
+import { ExcalidrawElement, ExcalidrawTextElement, NonDeleted } from "../types";
+import { getNonDeletedElements } from "../";
+
+import { isTextElement } from "../typeChecks";
+import { getContainerElement, redrawTextBoundingBox } from "../textElement";
+import { ShapeCache } from "../../scene/ShapeCache";
+import Scene from "../../scene/Scene";
+
+// Use "let" instead of "const" so we can dynamically add subtypes
+let subtypeNames: readonly Subtype[] = [];
+let parentTypeMap: readonly {
+  subtype: Subtype;
+  parentType: ExcalidrawElement["type"];
+}[] = [];
+
+export type SubtypeRecord = Readonly<{
+  subtype: Subtype;
+  parents: readonly ExcalidrawElement["type"][];
+}>;
+
+// Subtype Names
+export type Subtype = Required<ExcalidrawElement>["subtype"];
+export const getSubtypeNames = (): readonly Subtype[] => {
+  return subtypeNames;
+};
+
+// Subtype Methods
+export type SubtypeMethods = {
+  clean: (
+    updates: Omit<
+      Partial<ExcalidrawElement>,
+      "id" | "version" | "versionNonce"
+    >,
+  ) => Omit<Partial<ExcalidrawElement>, "id" | "version" | "versionNonce">;
+  ensureLoaded: (callback?: () => void) => Promise<void>;
+  getEditorStyle: (element: ExcalidrawTextElement) => Record<string, any>;
+  measureText: (
+    element: Pick<
+      ExcalidrawTextElement,
+      | "subtype"
+      | "customData"
+      | "fontSize"
+      | "fontFamily"
+      | "text"
+      | "lineHeight"
+    >,
+    next?: {
+      fontSize?: number;
+      text?: string;
+      customData?: ExcalidrawElement["customData"];
+    },
+  ) => { width: number; height: number; baseline: number };
+  render: (
+    element: NonDeleted<ExcalidrawElement>,
+    context: CanvasRenderingContext2D,
+  ) => void;
+  renderSvg: (
+    svgRoot: SVGElement,
+    root: SVGElement,
+    element: NonDeleted<ExcalidrawElement>,
+    opt?: { offsetX?: number; offsetY?: number },
+  ) => void;
+  wrapText: (
+    element: Pick<
+      ExcalidrawTextElement,
+      | "subtype"
+      | "customData"
+      | "fontSize"
+      | "fontFamily"
+      | "originalText"
+      | "lineHeight"
+    >,
+    containerWidth: number,
+    next?: {
+      fontSize?: number;
+      text?: string;
+      customData?: ExcalidrawElement["customData"];
+    },
+  ) => string;
+};
+
+type MethodMap = { subtype: Subtype; methods: Partial<SubtypeMethods> };
+const methodMaps = [] as Array<MethodMap>;
+
+// Use `getSubtypeMethods` to call subtype-specialized methods, like `render`.
+export const getSubtypeMethods = (
+  subtype: Subtype | undefined,
+): Partial<SubtypeMethods> | undefined => {
+  const map = methodMaps.find((method) => method.subtype === subtype);
+  return map?.methods;
+};
+
+export const addSubtypeMethods = (
+  subtype: Subtype,
+  methods: Partial<SubtypeMethods>,
+) => {
+  if (!subtypeNames.includes(subtype)) {
+    return;
+  }
+  if (!methodMaps.find((method) => method.subtype === subtype)) {
+    methodMaps.push({ subtype, methods });
+  }
+};
+
+// Callback to re-render subtyped `ExcalidrawElement`s after completing
+// async loading of the subtype.
+export type SubtypeLoadedCb = (hasSubtype: SubtypeCheckFn) => void;
+export type SubtypeCheckFn = (element: ExcalidrawElement) => boolean;
+
+// Functions to prepare subtypes for use
+export type SubtypePrepFn = (onSubtypeLoaded?: SubtypeLoadedCb) => {
+  methods: Partial<SubtypeMethods>;
+};
+
+// This is the main method to set up the subtype.  The optional
+// `onSubtypeLoaded` callback may be used to re-render subtyped
+// `ExcalidrawElement`s after the subtype has finished async loading.
+export const prepareSubtype = (
+  record: SubtypeRecord,
+  subtypePrepFn: SubtypePrepFn,
+  onSubtypeLoaded?: SubtypeLoadedCb,
+): { methods: Partial<SubtypeMethods> } => {
+  const map = getSubtypeMethods(record.subtype);
+  if (map) {
+    return { methods: map };
+  }
+
+  // Check for undefined/null subtypes and parentTypes
+  if (
+    record.subtype === undefined ||
+    record.subtype === "" ||
+    record.parents === undefined ||
+    record.parents.length === 0
+  ) {
+    return { methods: {} };
+  }
+
+  // Register the types
+  const subtype = record.subtype;
+  subtypeNames = [...subtypeNames, subtype];
+  record.parents.forEach((parentType) => {
+    parentTypeMap = [...parentTypeMap, { subtype, parentType }];
+  });
+
+  // Prepare the subtype
+  const { methods } = subtypePrepFn(onSubtypeLoaded);
+
+  // Register the subtype's methods
+  addSubtypeMethods(record.subtype, methods);
+  return { methods };
+};
+
+// Ensure all subtypes are loaded before continuing, eg to
+// redraw text element bounding boxes correctly.
+export const ensureSubtypesLoadedForElements = async (
+  elements: readonly ExcalidrawElement[],
+  callback?: () => void,
+) => {
+  // Only ensure the loading of subtypes which are actually needed.
+  // We don't want to be held up by eg downloading the MathJax SVG fonts
+  // if we don't actually need them yet.
+  const subtypesUsed = [] as Subtype[];
+  elements.forEach((el) => {
+    if (
+      "subtype" in el &&
+      el.subtype !== undefined &&
+      !subtypesUsed.includes(el.subtype)
+    ) {
+      subtypesUsed.push(el.subtype);
+    }
+  });
+  await ensureSubtypesLoaded(subtypesUsed, callback);
+};
+
+export const ensureSubtypesLoaded = async (
+  subtypes: Subtype[],
+  callback?: () => void,
+) => {
+  // Use a for loop so we can do `await map.ensureLoaded()`
+  for (let i = 0; i < subtypes.length; i++) {
+    const subtype = subtypes[i];
+    // Should be defined if prepareSubtype() has run
+    const map = getSubtypeMethods(subtype);
+    if (map?.ensureLoaded) {
+      await map.ensureLoaded();
+    }
+  }
+  if (callback) {
+    callback();
+  }
+};
+
+// Call this method after finishing any async loading for
+// subtypes of ExcalidrawElement if the newly loaded code
+// would change the rendering.
+export const checkRefreshOnSubtypeLoad = (
+  hasSubtype: SubtypeCheckFn,
+  elements: readonly ExcalidrawElement[],
+) => {
+  let refreshNeeded = false;
+  const scenes: Scene[] = [];
+  getNonDeletedElements(elements).forEach((element) => {
+    // If the element is of the subtype that was just
+    // registered, update the element's dimensions, mark the
+    // element for a re-render, and indicate the scene needs a refresh.
+    if (hasSubtype(element)) {
+      ShapeCache.delete(element);
+      if (isTextElement(element)) {
+        redrawTextBoundingBox(element, getContainerElement(element));
+      }
+      refreshNeeded = true;
+      const scene = Scene.getScene(element);
+      if (scene && !scenes.includes(scene)) {
+        // Store in case we have multiple scenes
+        scenes.push(scene);
+      }
+    }
+  });
+  // Only inform each scene once
+  scenes.forEach((scene) => scene.informMutation());
+  return refreshNeeded;
+};

+ 39 - 20
src/element/textElement.ts

@@ -1,3 +1,4 @@
+import { getSubtypeMethods, SubtypeMethods } from "./subtypes";
 import { getFontString, arrayToMap, isTestEnv } from "../utils";
 import {
   ExcalidrawElement,
@@ -36,6 +37,30 @@ import {
 } from "./textWysiwyg";
 import { ExtractSetType } from "../utility-types";
 
+export const measureTextElement = function (element, next) {
+  const map = getSubtypeMethods(element.subtype);
+  if (map?.measureText) {
+    return map.measureText(element, next);
+  }
+
+  const fontSize = next?.fontSize ?? element.fontSize;
+  const font = getFontString({ fontSize, fontFamily: element.fontFamily });
+  const text = next?.text ?? element.text;
+  return measureText(text, font, element.lineHeight);
+} as SubtypeMethods["measureText"];
+
+export const wrapTextElement = function (element, containerWidth, next) {
+  const map = getSubtypeMethods(element.subtype);
+  if (map?.wrapText) {
+    return map.wrapText(element, containerWidth, next);
+  }
+
+  const fontSize = next?.fontSize ?? element.fontSize;
+  const font = getFontString({ fontSize, fontFamily: element.fontFamily });
+  const text = next?.text ?? element.originalText;
+  return wrapText(text, font, containerWidth);
+} as SubtypeMethods["wrapText"];
+
 export const normalizeText = (text: string) => {
   return (
     text
@@ -68,22 +93,24 @@ export const redrawTextBoundingBox = (
 
   if (container) {
     maxWidth = getBoundTextMaxWidth(container, textElement);
-    boundTextUpdates.text = wrapText(
-      textElement.originalText,
-      getFontString(textElement),
-      maxWidth,
-    );
+    boundTextUpdates.text = wrapTextElement(textElement, maxWidth);
   }
-  const metrics = measureText(
-    boundTextUpdates.text,
-    getFontString(textElement),
-    textElement.lineHeight,
-  );
+  const metrics = measureTextElement(textElement, {
+    text: boundTextUpdates.text,
+  });
 
   boundTextUpdates.width = metrics.width;
   boundTextUpdates.height = metrics.height;
   boundTextUpdates.baseline = metrics.baseline;
 
+  // Maintain coordX for non left-aligned text in case the width has changed
+  if (!container) {
+    if (textElement.textAlign === TEXT_ALIGN.RIGHT) {
+      boundTextUpdates.x += textElement.width - metrics.width;
+    } else if (textElement.textAlign === TEXT_ALIGN.CENTER) {
+      boundTextUpdates.x += textElement.width / 2 - metrics.width / 2;
+    }
+  }
   if (container) {
     const maxContainerHeight = getBoundTextMaxHeight(
       container,
@@ -196,17 +223,9 @@ export const handleBindTextResize = (
       (transformHandleType !== "n" && transformHandleType !== "s")
     ) {
       if (text) {
-        text = wrapText(
-          textElement.originalText,
-          getFontString(textElement),
-          maxWidth,
-        );
+        text = wrapTextElement(textElement, maxWidth);
       }
-      const metrics = measureText(
-        text,
-        getFontString(textElement),
-        textElement.lineHeight,
-      );
+      const metrics = measureTextElement(textElement, { text });
       nextHeight = metrics.height;
       nextWidth = metrics.width;
       nextBaseLine = metrics.baseline;

+ 57 - 7
src/element/textWysiwyg.tsx

@@ -26,6 +26,7 @@ import {
   getContainerElement,
   getTextElementAngle,
   getTextWidth,
+  measureText,
   normalizeText,
   redrawTextBoundingBox,
   wrapText,
@@ -43,8 +44,10 @@ import { actionZoomIn, actionZoomOut } from "../actions/actionCanvas";
 import App from "../components/App";
 import { LinearElementEditor } from "./linearElementEditor";
 import { parseClipboard } from "../clipboard";
+import { SubtypeMethods, getSubtypeMethods } from "./subtypes";
 
 const getTransform = (
+  offsetX: number,
   width: number,
   height: number,
   angle: number,
@@ -62,7 +65,8 @@ const getTransform = (
   if (height > maxHeight && zoom.value !== 1) {
     translateY = (maxHeight * (zoom.value - 1)) / 2;
   }
-  return `translate(${translateX}px, ${translateY}px) scale(${zoom.value}) rotate(${degree}deg)`;
+  const offset = offsetX !== 0 ? ` translate(${offsetX}px, 0px)` : "";
+  return `translate(${translateX}px, ${translateY}px) scale(${zoom.value}) rotate(${degree}deg)${offset}`;
 };
 
 const originalContainerCache: {
@@ -97,6 +101,14 @@ export const getOriginalContainerHeightFromCache = (
   return originalContainerCache[id]?.height ?? null;
 };
 
+const getEditorStyle = function (element) {
+  const map = getSubtypeMethods(element.subtype);
+  if (map?.getEditorStyle) {
+    return map.getEditorStyle(element);
+  }
+  return {};
+} as SubtypeMethods["getEditorStyle"];
+
 export const textWysiwyg = ({
   id,
   onChange,
@@ -156,11 +168,24 @@ export const textWysiwyg = ({
       const container = getContainerElement(updatedTextElement);
       let maxWidth = updatedTextElement.width;
 
-      let maxHeight = updatedTextElement.height;
-      let textElementWidth = updatedTextElement.width;
+      // Editing metrics
+      const eMetrics = measureText(
+        container && updatedTextElement.containerId
+          ? wrapText(
+              updatedTextElement.originalText,
+              getFontString(updatedTextElement),
+              getBoundTextMaxWidth(container),
+            )
+          : updatedTextElement.originalText,
+        getFontString(updatedTextElement),
+        updatedTextElement.lineHeight,
+      );
+
+      let maxHeight = eMetrics.height;
+      let textElementWidth = Math.max(updatedTextElement.width, eMetrics.width);
       // Set to element height by default since that's
       // what is going to be used for unbounded text
-      const textElementHeight = updatedTextElement.height;
+      const textElementHeight = Math.max(updatedTextElement.height, maxHeight);
 
       if (container && updatedTextElement.containerId) {
         if (isArrowElement(container)) {
@@ -246,13 +271,35 @@ export const textWysiwyg = ({
         editable.selectionEnd = editable.value.length - diff;
       }
 
+      let transformWidth = updatedTextElement.width;
       if (!container) {
         maxWidth = (appState.width - 8 - viewportX) / appState.zoom.value;
         textElementWidth = Math.min(textElementWidth, maxWidth);
       } else {
         textElementWidth += 0.5;
+        transformWidth += 0.5;
       }
 
+      // Horizontal offset in case updatedTextElement has a non-WYSIWYG subtype
+      const offWidth = container
+        ? Math.min(
+            0,
+            updatedTextElement.width - Math.min(maxWidth, eMetrics.width),
+          )
+        : Math.min(maxWidth, updatedTextElement.width) -
+          Math.min(maxWidth, eMetrics.width);
+      const offsetX =
+        textAlign === "right"
+          ? offWidth
+          : textAlign === "center"
+          ? offWidth / 2
+          : 0;
+      const { width: w, height: h } = updatedTextElement;
+      const transformOrigin =
+        updatedTextElement.width !== eMetrics.width ||
+        updatedTextElement.height !== eMetrics.height
+          ? { transformOrigin: `${w / 2}px ${h / 2}px` }
+          : {};
       let lineHeight = updatedTextElement.lineHeight;
 
       // In Safari the font size gets rounded off when rendering hence calculating the line height by rounding off font size
@@ -270,13 +317,15 @@ export const textWysiwyg = ({
         font: getFontString(updatedTextElement),
         // must be defined *after* font ¯\_(ツ)_/¯
         lineHeight,
-        width: `${textElementWidth}px`,
+        width: `${Math.min(textElementWidth, maxWidth)}px`,
         height: `${textElementHeight}px`,
         left: `${viewportX}px`,
         top: `${viewportY}px`,
+        ...transformOrigin,
         transform: getTransform(
-          textElementWidth,
-          textElementHeight,
+          offsetX,
+          transformWidth,
+          updatedTextElement.height,
           getTextElementAngle(updatedTextElement),
           appState,
           maxWidth,
@@ -334,6 +383,7 @@ export const textWysiwyg = ({
     whiteSpace,
     overflowWrap: "break-word",
     boxSizing: "content-box",
+    ...getEditorStyle(element),
   });
   editable.value = element.originalText;
   updateWysiwygStyle();

+ 1 - 0
src/element/types.ts

@@ -65,6 +65,7 @@ type _ExcalidrawElementBase = Readonly<{
   updated: number;
   link: string | null;
   locked: boolean;
+  subtype?: string;
   customData?: Record<string, any>;
 }>;
 

+ 12 - 0
src/renderer/renderElement.ts

@@ -31,6 +31,7 @@ import {
   InteractiveCanvasAppState,
 } from "../types";
 import { getDefaultAppState } from "../appState";
+import { getSubtypeMethods } from "../element/subtypes";
 import {
   BOUND_TEXT_PADDING,
   FRAME_STYLE,
@@ -264,6 +265,12 @@ const drawElementOnCanvas = (
 ) => {
   context.globalAlpha =
     ((getContainingFrame(element)?.opacity ?? 100) * element.opacity) / 10000;
+  const map = getSubtypeMethods(element.subtype);
+  if (map?.render) {
+    map.render(element, context);
+    context.globalAlpha = 1;
+    return;
+  }
   switch (element.type) {
     case "rectangle":
     case "embeddable":
@@ -897,6 +904,11 @@ export const renderElementToSvg = (
     root = anchorTag;
   }
 
+  const map = getSubtypeMethods(element.subtype);
+  if (map?.renderSvg) {
+    map.renderSvg(svgRoot, root, element, { offsetX, offsetY });
+    return;
+  }
   const opacity =
     ((getContainingFrame(element)?.opacity ?? 100) * element.opacity) / 10000;
 

+ 25 - 0
src/tests/helpers/api.ts

@@ -16,6 +16,14 @@ import util from "util";
 import path from "path";
 import { getMimeType } from "../../data/blob";
 import {
+  SubtypeLoadedCb,
+  SubtypePrepFn,
+  SubtypeRecord,
+  checkRefreshOnSubtypeLoad,
+  prepareSubtype,
+} from "../../element/subtypes";
+import {
+  maybeGetSubtypeProps,
   newEmbeddableElement,
   newFrameElement,
   newFreeDrawElement,
@@ -32,6 +40,16 @@ const readFile = util.promisify(fs.readFile);
 const { h } = window;
 
 export class API {
+  static addSubtype = (record: SubtypeRecord, subtypePrepFn: SubtypePrepFn) => {
+    const subtypeLoadedCb: SubtypeLoadedCb = (hasSubtype) => {
+      if (checkRefreshOnSubtypeLoad(hasSubtype, h.elements)) {
+        h.app.refresh();
+      }
+    };
+    const prep = prepareSubtype(record, subtypePrepFn, subtypeLoadedCb);
+    return prep;
+  };
+
   static setSelectedElements = (elements: ExcalidrawElement[]) => {
     h.setState({
       selectedElementIds: elements.reduce((acc, element) => {
@@ -113,6 +131,8 @@ export class API {
     verticalAlign?: T extends "text"
       ? ExcalidrawTextElement["verticalAlign"]
       : never;
+    subtype?: ExcalidrawElement["subtype"];
+    customData?: ExcalidrawElement["customData"];
     boundElements?: ExcalidrawGenericElement["boundElements"];
     containerId?: T extends "text"
       ? ExcalidrawTextElement["containerId"]
@@ -141,6 +161,10 @@ export class API {
 
     const appState = h?.state || getDefaultAppState();
 
+    const custom = maybeGetSubtypeProps({
+      subtype: rest.subtype,
+      customData: rest.customData,
+    });
     const base: Omit<
       ExcalidrawGenericElement,
       | "id"
@@ -155,6 +179,7 @@ export class API {
       | "link"
       | "updated"
     > = {
+      ...custom,
       x,
       y,
       frameId: rest.frameId ?? null,

+ 395 - 0
src/tests/subtypes.test.tsx

@@ -0,0 +1,395 @@
+import { vi } from "vitest";
+import {
+  SubtypeLoadedCb,
+  SubtypeRecord,
+  SubtypeMethods,
+  SubtypePrepFn,
+  addSubtypeMethods,
+  ensureSubtypesLoadedForElements,
+  getSubtypeMethods,
+  getSubtypeNames,
+} from "../element/subtypes";
+
+import { render } from "./test-utils";
+import { API } from "./helpers/api";
+import { Excalidraw, FONT_FAMILY } from "../packages/excalidraw/index";
+
+import {
+  ExcalidrawElement,
+  ExcalidrawTextElement,
+  FontString,
+} from "../element/types";
+import { getFontString } from "../utils";
+import * as textElementUtils from "../element/textElement";
+import { isTextElement } from "../element";
+import { mutateElement, newElementWith } from "../element/mutateElement";
+
+const MW = 200;
+const TWIDTH = 200;
+const THEIGHT = 20;
+const TBASELINE = 0;
+const FONTSIZE = 20;
+const DBFONTSIZE = 40;
+const TRFONTSIZE = 60;
+
+const test2: SubtypeRecord = {
+  subtype: "test2",
+  parents: ["text"],
+};
+
+const test3: SubtypeRecord = {
+  subtype: "test3",
+  parents: ["text", "line"],
+};
+
+const prepareNullSubtype = function () {
+  const methods = {} as SubtypeMethods;
+  methods.clean = cleanTest2ElementUpdate;
+  methods.measureText = measureTest2;
+  methods.wrapText = wrapTest2;
+
+  return { methods };
+} as SubtypePrepFn;
+
+const cleanTest2ElementUpdate = function (updates) {
+  const oldUpdates = {};
+  for (const key in updates) {
+    if (key !== "fontFamily") {
+      (oldUpdates as any)[key] = (updates as any)[key];
+    }
+  }
+  (updates as any).fontFamily = FONT_FAMILY.Cascadia;
+  return oldUpdates;
+} as SubtypeMethods["clean"];
+
+let test2Loaded = false;
+
+const ensureLoadedTest2: SubtypeMethods["ensureLoaded"] = async (callback) => {
+  test2Loaded = true;
+  if (onTest2Loaded) {
+    onTest2Loaded((el) => isTextElement(el) && el.subtype === test2.subtype);
+  }
+  if (callback) {
+    callback();
+  }
+};
+
+const measureTest2: SubtypeMethods["measureText"] = function (element, next) {
+  const text = next?.text ?? element.text;
+  const customData = next?.customData ?? {};
+  const fontSize = customData.triple
+    ? TRFONTSIZE
+    : next?.fontSize ?? element.fontSize;
+  const fontFamily = element.fontFamily;
+  const fontString = getFontString({ fontSize, fontFamily });
+  const lineHeight = element.lineHeight;
+  const metrics = textElementUtils.measureText(text, fontString, lineHeight);
+  const width = test2Loaded
+    ? metrics.width * 2
+    : Math.max(metrics.width - 10, 0);
+  const height = test2Loaded
+    ? metrics.height * 2
+    : Math.max(metrics.height - 5, 0);
+  return { width, height, baseline: 1 };
+};
+
+const wrapTest2: SubtypeMethods["wrapText"] = function (
+  element,
+  maxWidth,
+  next,
+) {
+  const text = next?.text ?? element.originalText;
+  if (next?.customData && next?.customData.triple === true) {
+    return `${text.split(" ").join("\n")}\nHELLO WORLD.`;
+  }
+  if (next?.fontSize === DBFONTSIZE) {
+    return `${text.split(" ").join("\n")}\nHELLO World.`;
+  }
+  return `${text.split(" ").join("\n")}\nHello world.`;
+};
+
+let onTest2Loaded: SubtypeLoadedCb | undefined;
+
+const prepareTest2Subtype = function (onSubtypeLoaded) {
+  const methods = {
+    clean: cleanTest2ElementUpdate,
+    ensureLoaded: ensureLoadedTest2,
+    measureText: measureTest2,
+    wrapText: wrapTest2,
+  } as SubtypeMethods;
+
+  onTest2Loaded = onSubtypeLoaded;
+
+  return { methods };
+} as SubtypePrepFn;
+
+const prepareTest3Subtype = function () {
+  const methods = {} as SubtypeMethods;
+
+  return { methods };
+} as SubtypePrepFn;
+
+const { h } = window;
+
+describe("subtype registration", () => {
+  it("should check for invalid subtype or parents", async () => {
+    await render(<Excalidraw />, {});
+    // Define invalid subtype records
+    const null1 = {} as SubtypeRecord;
+    const null2 = { subtype: "" } as SubtypeRecord;
+    const null3 = { subtype: "null" } as SubtypeRecord;
+    const null4 = { subtype: "null", parents: [] } as SubtypeRecord;
+    // Try registering the invalid subtypes
+    const prepN1 = API.addSubtype(null1, prepareNullSubtype);
+    const prepN2 = API.addSubtype(null2, prepareNullSubtype);
+    const prepN3 = API.addSubtype(null3, prepareNullSubtype);
+    const prepN4 = API.addSubtype(null4, prepareNullSubtype);
+    // Verify the guards in `prepareSubtype` worked
+    expect(prepN1).toStrictEqual({ methods: {} });
+    expect(prepN2).toStrictEqual({ methods: {} });
+    expect(prepN3).toStrictEqual({ methods: {} });
+    expect(prepN4).toStrictEqual({ methods: {} });
+  });
+  it("should return subtype methods correctly", async () => {
+    // Check initial registration works
+    let prep2 = API.addSubtype(test2, prepareTest2Subtype);
+    expect(prep2.methods).toStrictEqual({
+      clean: cleanTest2ElementUpdate,
+      ensureLoaded: ensureLoadedTest2,
+      measureText: measureTest2,
+      wrapText: wrapTest2,
+    });
+    // Check repeat registration fails
+    prep2 = API.addSubtype(test2, prepareNullSubtype);
+    expect(prep2.methods).toStrictEqual({
+      clean: cleanTest2ElementUpdate,
+      ensureLoaded: ensureLoadedTest2,
+      measureText: measureTest2,
+      wrapText: wrapTest2,
+    });
+
+    // Check initial registration works
+    let prep3 = API.addSubtype(test3, prepareTest3Subtype);
+    expect(prep3.methods).toStrictEqual({});
+    // Check repeat registration fails
+    prep3 = API.addSubtype(test3, prepareNullSubtype);
+    expect(prep3.methods).toStrictEqual({});
+  });
+});
+
+describe("subtypes", () => {
+  it("should correctly register", async () => {
+    const subtypes = getSubtypeNames();
+    expect(subtypes).toContain(test2.subtype);
+    expect(subtypes).toContain(test3.subtype);
+  });
+  it("should return subtype methods", async () => {
+    expect(getSubtypeMethods(undefined)).toBeUndefined();
+    const test2Methods = getSubtypeMethods(test2.subtype);
+    expect(test2Methods?.clean).toStrictEqual(cleanTest2ElementUpdate);
+    expect(test2Methods?.ensureLoaded).toStrictEqual(ensureLoadedTest2);
+    expect(test2Methods?.measureText).toStrictEqual(measureTest2);
+    expect(test2Methods?.render).toBeUndefined();
+    expect(test2Methods?.renderSvg).toBeUndefined();
+    expect(test2Methods?.wrapText).toStrictEqual(wrapTest2);
+  });
+  it("should not overwrite subtype methods", async () => {
+    addSubtypeMethods(test2.subtype, {});
+    addSubtypeMethods(test3.subtype, { clean: cleanTest2ElementUpdate });
+    const test2Methods = getSubtypeMethods(test2.subtype);
+    expect(test2Methods?.measureText).toStrictEqual(measureTest2);
+    expect(test2Methods?.wrapText).toStrictEqual(wrapTest2);
+    const test3Methods = getSubtypeMethods(test3.subtype);
+    expect(test3Methods?.clean).toBeUndefined();
+  });
+  it("should apply to ExcalidrawElements", async () => {
+    const elements = [
+      API.createElement({ type: "text", id: "A", subtype: test3.subtype }),
+      API.createElement({ type: "line", id: "B", subtype: test3.subtype }),
+    ];
+    await render(<Excalidraw />, { localStorageData: { elements } });
+    elements.forEach((el) => expect(el.subtype).toBe(test3.subtype));
+  });
+  it("should enforce prop value restrictions", async () => {
+    const elements = [
+      API.createElement({
+        type: "text",
+        id: "A",
+        subtype: test2.subtype,
+        fontFamily: FONT_FAMILY.Virgil,
+      }),
+      API.createElement({
+        type: "text",
+        id: "B",
+        fontFamily: FONT_FAMILY.Virgil,
+      }),
+    ];
+    await render(<Excalidraw />, { localStorageData: { elements } });
+    elements.forEach((el) => {
+      if (el.subtype === test2.subtype) {
+        expect(el.fontFamily).toBe(FONT_FAMILY.Cascadia);
+      } else {
+        expect(el.fontFamily).toBe(FONT_FAMILY.Virgil);
+      }
+    });
+  });
+  it("should consider enforced prop values in version increments", async () => {
+    const rectA = API.createElement({
+      type: "text",
+      id: "A",
+      subtype: test2.subtype,
+      fontFamily: FONT_FAMILY.Virgil,
+      fontSize: 10,
+    });
+    const rectB = API.createElement({
+      type: "text",
+      id: "B",
+      subtype: test2.subtype,
+      fontFamily: FONT_FAMILY.Virgil,
+      fontSize: 10,
+    });
+    // Initial element creation checks
+    expect(rectA.fontFamily).toBe(FONT_FAMILY.Cascadia);
+    expect(rectB.fontFamily).toBe(FONT_FAMILY.Cascadia);
+    expect(rectA.version).toBe(1);
+    expect(rectB.version).toBe(1);
+    // Check that attempting to set prop values not permitted by the subtype
+    // doesn't increment element versions
+    mutateElement(rectA, { fontFamily: FONT_FAMILY.Helvetica });
+    mutateElement(rectB, { fontFamily: FONT_FAMILY.Helvetica, fontSize: 20 });
+    expect(rectA.version).toBe(1);
+    expect(rectB.version).toBe(2);
+    // Check that element versions don't increment when creating new elements
+    // while attempting to use prop values not permitted by the subtype
+    // First check based on `rectA` (unsuccessfully mutated)
+    const rectC = newElementWith(rectA, { fontFamily: FONT_FAMILY.Virgil });
+    const rectD = newElementWith(rectA, {
+      fontFamily: FONT_FAMILY.Virgil,
+      fontSize: 15,
+    });
+    expect(rectC.version).toBe(1);
+    expect(rectD.version).toBe(2);
+    // Then check based on `rectB` (successfully mutated)
+    const rectE = newElementWith(rectB, { fontFamily: FONT_FAMILY.Virgil });
+    const rectF = newElementWith(rectB, {
+      fontFamily: FONT_FAMILY.Virgil,
+      fontSize: 15,
+    });
+    expect(rectE.version).toBe(2);
+    expect(rectF.version).toBe(3);
+  });
+  it("should call custom text methods", async () => {
+    const testString = "A quick brown fox jumps over the lazy dog.";
+    const elements = [
+      API.createElement({
+        type: "text",
+        id: "A",
+        subtype: test2.subtype,
+        text: testString,
+        fontSize: FONTSIZE,
+      }),
+    ];
+    await render(<Excalidraw />, { localStorageData: { elements } });
+    const mockMeasureText = (text: string, font: FontString) => {
+      if (text === testString) {
+        let multiplier = 1;
+        if (font.includes(`${DBFONTSIZE}`)) {
+          multiplier = 2;
+        }
+        if (font.includes(`${TRFONTSIZE}`)) {
+          multiplier = 3;
+        }
+        const width = multiplier * TWIDTH;
+        const height = multiplier * THEIGHT;
+        const baseline = multiplier * TBASELINE;
+        return { width, height, baseline };
+      }
+      return { width: 1, height: 0, baseline: 0 };
+    };
+
+    vi.spyOn(textElementUtils, "measureText").mockImplementation(
+      mockMeasureText,
+    );
+
+    elements.forEach((el) => {
+      if (isTextElement(el)) {
+        // First test with `ExcalidrawTextElement.text`
+        const metrics = textElementUtils.measureTextElement(el);
+        expect(metrics).toStrictEqual({
+          width: TWIDTH - 10,
+          height: THEIGHT - 5,
+          baseline: TBASELINE + 1,
+        });
+        const wrappedText = textElementUtils.wrapTextElement(el, MW);
+        expect(wrappedText).toEqual(
+          `${testString.split(" ").join("\n")}\nHello world.`,
+        );
+
+        // Now test with modified text in `next`
+        let next: {
+          text?: string;
+          fontSize?: number;
+          customData?: Record<string, any>;
+        } = {
+          text: "Hello world.",
+        };
+        const nextMetrics = textElementUtils.measureTextElement(el, next);
+        expect(nextMetrics).toStrictEqual({ width: 0, height: 0, baseline: 1 });
+        const nextWrappedText = textElementUtils.wrapTextElement(el, MW, next);
+        expect(nextWrappedText).toEqual("Hello\nworld.\nHello world.");
+
+        // Now test modified fontSizes in `next`
+        next = { fontSize: DBFONTSIZE };
+        const nextFM = textElementUtils.measureTextElement(el, next);
+        expect(nextFM).toStrictEqual({
+          width: 2 * TWIDTH - 10,
+          height: 2 * THEIGHT - 5,
+          baseline: 2 * TBASELINE + 1,
+        });
+        const nextFWrText = textElementUtils.wrapTextElement(el, MW, next);
+        expect(nextFWrText).toEqual(
+          `${testString.split(" ").join("\n")}\nHELLO World.`,
+        );
+
+        // Now test customData in `next`
+        next = { customData: { triple: true } };
+        const nextCD = textElementUtils.measureTextElement(el, next);
+        expect(nextCD).toStrictEqual({
+          width: 3 * TWIDTH - 10,
+          height: 3 * THEIGHT - 5,
+          baseline: 3 * TBASELINE + 1,
+        });
+        const nextCDWrText = textElementUtils.wrapTextElement(el, MW, next);
+        expect(nextCDWrText).toEqual(
+          `${testString.split(" ").join("\n")}\nHELLO WORLD.`,
+        );
+      }
+    });
+  });
+});
+describe("subtype loading", () => {
+  let elements: ExcalidrawElement[];
+  beforeEach(async () => {
+    const testString = "A quick brown fox jumps over the lazy dog.";
+    elements = [
+      API.createElement({
+        type: "text",
+        id: "A",
+        subtype: test2.subtype,
+        text: testString,
+      }),
+    ];
+    await render(<Excalidraw />, { localStorageData: { elements } });
+    h.elements = elements;
+  });
+  it("should redraw text bounding boxes", async () => {
+    h.setState({ selectedElementIds: { A: true } });
+    const el = h.elements[0] as ExcalidrawTextElement;
+    expect(el.width).toEqual(100);
+    expect(el.height).toEqual(100);
+    ensureSubtypesLoadedForElements(elements);
+    expect(el.width).toEqual(TWIDTH * 2);
+    expect(el.height).toEqual(THEIGHT * 2);
+    expect(el.baseline).toEqual(TBASELINE + 1);
+  });
+});