api.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. import fs from "fs";
  2. import path from "path";
  3. import util from "util";
  4. import { pointFrom, type LocalPoint, type Radians } from "@excalidraw/math";
  5. import { DEFAULT_VERTICAL_ALIGN, ROUNDNESS, assertNever } from "@excalidraw/common";
  6. import { mutateElement } from "@excalidraw/element/mutateElement";
  7. import {
  8. newArrowElement,
  9. newElement,
  10. newEmbeddableElement,
  11. newFrameElement,
  12. newFreeDrawElement,
  13. newIframeElement,
  14. newImageElement,
  15. newLinearElement,
  16. newMagicFrameElement,
  17. newTextElement,
  18. } from "@excalidraw/element/newElement";
  19. import { isLinearElementType } from "@excalidraw/element/typeChecks";
  20. import { getSelectedElements } from "@excalidraw/element/selection";
  21. import { selectGroupsForSelectedElements } from "@excalidraw/element/groups";
  22. import type {
  23. ExcalidrawElement,
  24. ExcalidrawGenericElement,
  25. ExcalidrawTextElement,
  26. ExcalidrawLinearElement,
  27. ExcalidrawFreeDrawElement,
  28. ExcalidrawImageElement,
  29. FileId,
  30. ExcalidrawFrameElement,
  31. ExcalidrawElementType,
  32. ExcalidrawMagicFrameElement,
  33. ExcalidrawElbowArrowElement,
  34. ExcalidrawArrowElement,
  35. FixedSegment,
  36. } from "@excalidraw/element/types";
  37. import type { Mutable } from "@excalidraw/common/utility-types";
  38. import { getMimeType } from "../../data/blob";
  39. import { createTestHook } from "../../components/App";
  40. import { getDefaultAppState } from "../../appState";
  41. import { GlobalTestState, createEvent, fireEvent, act } from "../test-utils";
  42. import type { Action } from "../../actions/types";
  43. import type App from "../../components/App";
  44. import type { AppState } from "../../types";
  45. const readFile = util.promisify(fs.readFile);
  46. // so that window.h is available when App.tsx is not imported as well.
  47. createTestHook();
  48. const { h } = window;
  49. export class API {
  50. static updateScene: InstanceType<typeof App>["updateScene"] = (...args) => {
  51. act(() => {
  52. h.app.updateScene(...args);
  53. });
  54. };
  55. static setAppState: React.Component<any, AppState>["setState"] = (
  56. state,
  57. cb,
  58. ) => {
  59. act(() => {
  60. h.setState(state, cb);
  61. });
  62. };
  63. static setElements = (elements: readonly ExcalidrawElement[]) => {
  64. act(() => {
  65. h.elements = elements;
  66. });
  67. };
  68. static setSelectedElements = (elements: ExcalidrawElement[], editingGroupId?: string | null) => {
  69. act(() => {
  70. h.setState({
  71. ...selectGroupsForSelectedElements(
  72. {
  73. editingGroupId: editingGroupId ?? null,
  74. selectedElementIds: elements.reduce((acc, element) => {
  75. acc[element.id] = true;
  76. return acc;
  77. }, {} as Record<ExcalidrawElement["id"], true>),
  78. },
  79. elements,
  80. h.state,
  81. h.app,
  82. )
  83. });
  84. });
  85. };
  86. // eslint-disable-next-line prettier/prettier
  87. static updateElement = <T extends ExcalidrawElement>(
  88. ...args: Parameters<typeof mutateElement<T>>
  89. ) => {
  90. act(() => {
  91. mutateElement<T>(...args);
  92. });
  93. };
  94. static getSelectedElements = (
  95. includeBoundTextElement: boolean = false,
  96. includeElementsInFrames: boolean = false,
  97. ): ExcalidrawElement[] => {
  98. return getSelectedElements(h.elements, h.state, {
  99. includeBoundTextElement,
  100. includeElementsInFrames,
  101. });
  102. };
  103. static getSelectedElement = (): ExcalidrawElement => {
  104. const selectedElements = API.getSelectedElements();
  105. if (selectedElements.length !== 1) {
  106. throw new Error(
  107. `expected 1 selected element; got ${selectedElements.length}`,
  108. );
  109. }
  110. return selectedElements[0];
  111. };
  112. static getUndoStack = () => {
  113. // @ts-ignore
  114. return h.history.undoStack;
  115. };
  116. static getRedoStack = () => {
  117. // @ts-ignore
  118. return h.history.redoStack;
  119. };
  120. static getSnapshot = () => {
  121. return Array.from(h.store.snapshot.elements.values());
  122. };
  123. static clearSelection = () => {
  124. act(() => {
  125. // @ts-ignore
  126. h.app.clearSelection(null);
  127. });
  128. expect(API.getSelectedElements().length).toBe(0);
  129. };
  130. static getElement = <T extends ExcalidrawElement>(element: T): T => {
  131. return h.app.scene.getElementsMapIncludingDeleted().get(element.id) as T || element;
  132. }
  133. static createElement = <
  134. T extends Exclude<ExcalidrawElementType, "selection"> = "rectangle",
  135. >({
  136. // @ts-ignore
  137. type = "rectangle",
  138. id,
  139. x = 0,
  140. y = x,
  141. width = 100,
  142. height = width,
  143. isDeleted = false,
  144. groupIds = [],
  145. ...rest
  146. }: {
  147. type?: T;
  148. x?: number;
  149. y?: number;
  150. height?: number;
  151. width?: number;
  152. angle?: number;
  153. id?: string;
  154. isDeleted?: boolean;
  155. frameId?: ExcalidrawElement["id"] | null;
  156. index?: ExcalidrawElement["index"];
  157. groupIds?: ExcalidrawElement["groupIds"];
  158. // generic element props
  159. strokeColor?: ExcalidrawGenericElement["strokeColor"];
  160. backgroundColor?: ExcalidrawGenericElement["backgroundColor"];
  161. fillStyle?: ExcalidrawGenericElement["fillStyle"];
  162. strokeWidth?: ExcalidrawGenericElement["strokeWidth"];
  163. strokeStyle?: ExcalidrawGenericElement["strokeStyle"];
  164. roundness?: ExcalidrawGenericElement["roundness"];
  165. roughness?: ExcalidrawGenericElement["roughness"];
  166. opacity?: ExcalidrawGenericElement["opacity"];
  167. // text props
  168. text?: T extends "text" ? ExcalidrawTextElement["text"] : never;
  169. fontSize?: T extends "text" ? ExcalidrawTextElement["fontSize"] : never;
  170. fontFamily?: T extends "text" ? ExcalidrawTextElement["fontFamily"] : never;
  171. textAlign?: T extends "text" ? ExcalidrawTextElement["textAlign"] : never;
  172. verticalAlign?: T extends "text"
  173. ? ExcalidrawTextElement["verticalAlign"]
  174. : never;
  175. boundElements?: ExcalidrawGenericElement["boundElements"];
  176. containerId?: T extends "text"
  177. ? ExcalidrawTextElement["containerId"]
  178. : never;
  179. points?: T extends "arrow" | "line" | "freedraw" ? readonly LocalPoint[] : never;
  180. locked?: boolean;
  181. fileId?: T extends "image" ? string : never;
  182. scale?: T extends "image" ? ExcalidrawImageElement["scale"] : never;
  183. status?: T extends "image" ? ExcalidrawImageElement["status"] : never;
  184. startBinding?: T extends "arrow"
  185. ? ExcalidrawArrowElement["startBinding"] | ExcalidrawElbowArrowElement["startBinding"]
  186. : never;
  187. endBinding?: T extends "arrow"
  188. ? ExcalidrawArrowElement["endBinding"] | ExcalidrawElbowArrowElement["endBinding"]
  189. : never;
  190. startArrowhead?: T extends "arrow"
  191. ? ExcalidrawArrowElement["startArrowhead"] | ExcalidrawElbowArrowElement["startArrowhead"]
  192. : never;
  193. endArrowhead?: T extends "arrow"
  194. ? ExcalidrawArrowElement["endArrowhead"] | ExcalidrawElbowArrowElement["endArrowhead"]
  195. : never;
  196. elbowed?: boolean;
  197. fixedSegments?: FixedSegment[] | null;
  198. }): T extends "arrow" | "line"
  199. ? ExcalidrawLinearElement
  200. : T extends "freedraw"
  201. ? ExcalidrawFreeDrawElement
  202. : T extends "text"
  203. ? ExcalidrawTextElement
  204. : T extends "image"
  205. ? ExcalidrawImageElement
  206. : T extends "frame"
  207. ? ExcalidrawFrameElement
  208. : T extends "magicframe"
  209. ? ExcalidrawMagicFrameElement
  210. : ExcalidrawGenericElement => {
  211. let element: Mutable<ExcalidrawElement> = null!;
  212. const appState = h?.state || getDefaultAppState();
  213. const base: Omit<
  214. ExcalidrawGenericElement,
  215. | "id"
  216. | "type"
  217. | "version"
  218. | "versionNonce"
  219. | "isDeleted"
  220. | "groupIds"
  221. | "link"
  222. | "updated"
  223. > = {
  224. seed: 1,
  225. x,
  226. y,
  227. width,
  228. height,
  229. frameId: rest.frameId ?? null,
  230. index: rest.index ?? null,
  231. angle: (rest.angle ?? 0) as Radians,
  232. strokeColor: rest.strokeColor ?? appState.currentItemStrokeColor,
  233. backgroundColor:
  234. rest.backgroundColor ?? appState.currentItemBackgroundColor,
  235. fillStyle: rest.fillStyle ?? appState.currentItemFillStyle,
  236. strokeWidth: rest.strokeWidth ?? appState.currentItemStrokeWidth,
  237. strokeStyle: rest.strokeStyle ?? appState.currentItemStrokeStyle,
  238. roundness: (
  239. rest.roundness === undefined
  240. ? appState.currentItemRoundness === "round"
  241. : rest.roundness
  242. )
  243. ? {
  244. type: isLinearElementType(type)
  245. ? ROUNDNESS.PROPORTIONAL_RADIUS
  246. : ROUNDNESS.ADAPTIVE_RADIUS,
  247. }
  248. : null,
  249. roughness: rest.roughness ?? appState.currentItemRoughness,
  250. opacity: rest.opacity ?? appState.currentItemOpacity,
  251. boundElements: rest.boundElements ?? null,
  252. locked: rest.locked ?? false,
  253. };
  254. switch (type) {
  255. case "rectangle":
  256. case "diamond":
  257. case "ellipse":
  258. element = newElement({
  259. type: type as "rectangle" | "diamond" | "ellipse",
  260. ...base,
  261. });
  262. break;
  263. case "embeddable":
  264. element = newEmbeddableElement({
  265. type: "embeddable",
  266. ...base,
  267. });
  268. break;
  269. case "iframe":
  270. element = newIframeElement({
  271. type: "iframe",
  272. ...base,
  273. });
  274. break;
  275. case "text":
  276. const fontSize = rest.fontSize ?? appState.currentItemFontSize;
  277. const fontFamily = rest.fontFamily ?? appState.currentItemFontFamily;
  278. element = newTextElement({
  279. ...base,
  280. text: rest.text || "test",
  281. fontSize,
  282. fontFamily,
  283. textAlign: rest.textAlign ?? appState.currentItemTextAlign,
  284. verticalAlign: rest.verticalAlign ?? DEFAULT_VERTICAL_ALIGN,
  285. containerId: rest.containerId ?? undefined,
  286. });
  287. element.width = width;
  288. element.height = height;
  289. break;
  290. case "freedraw":
  291. element = newFreeDrawElement({
  292. type: type as "freedraw",
  293. simulatePressure: true,
  294. points: rest.points,
  295. ...base,
  296. });
  297. break;
  298. case "arrow":
  299. element = newArrowElement({
  300. ...base,
  301. width,
  302. height,
  303. type,
  304. points: rest.points ?? [
  305. pointFrom<LocalPoint>(0, 0),
  306. pointFrom<LocalPoint>(100, 100),
  307. ],
  308. elbowed: rest.elbowed ?? false,
  309. });
  310. break;
  311. case "line":
  312. element = newLinearElement({
  313. ...base,
  314. width,
  315. height,
  316. type,
  317. points: rest.points ?? [
  318. pointFrom<LocalPoint>(0, 0),
  319. pointFrom<LocalPoint>(100, 100),
  320. ],
  321. });
  322. break;
  323. case "image":
  324. element = newImageElement({
  325. ...base,
  326. width,
  327. height,
  328. type,
  329. fileId: (rest.fileId as string as FileId) ?? null,
  330. status: rest.status || "saved",
  331. scale: rest.scale || [1, 1],
  332. });
  333. break;
  334. case "frame":
  335. element = newFrameElement({ ...base, width, height });
  336. break;
  337. case "magicframe":
  338. element = newMagicFrameElement({ ...base, width, height });
  339. break;
  340. default:
  341. assertNever(
  342. type,
  343. `API.createElement: unimplemented element type ${type}}`,
  344. );
  345. break;
  346. }
  347. if (element.type === "arrow") {
  348. element.startBinding = rest.startBinding ?? null;
  349. element.endBinding = rest.endBinding ?? null;
  350. element.startArrowhead = rest.startArrowhead ?? null;
  351. element.endArrowhead = rest.endArrowhead ?? null;
  352. }
  353. if (id) {
  354. element.id = id;
  355. }
  356. if (isDeleted) {
  357. element.isDeleted = isDeleted;
  358. }
  359. if (groupIds) {
  360. element.groupIds = groupIds;
  361. }
  362. return element as any;
  363. };
  364. static createTextContainer = (opts?: {
  365. frameId?: ExcalidrawElement["id"];
  366. groupIds?: ExcalidrawElement["groupIds"];
  367. label?: {
  368. text?: string;
  369. frameId?: ExcalidrawElement["id"] | null;
  370. groupIds?: ExcalidrawElement["groupIds"];
  371. };
  372. }) => {
  373. const rectangle = API.createElement({
  374. type: "rectangle",
  375. frameId: opts?.frameId || null,
  376. groupIds: opts?.groupIds,
  377. });
  378. const text = API.createElement({
  379. type: "text",
  380. text: opts?.label?.text || "sample-text",
  381. width: 50,
  382. height: 20,
  383. fontSize: 16,
  384. containerId: rectangle.id,
  385. frameId:
  386. opts?.label?.frameId === undefined
  387. ? opts?.frameId ?? null
  388. : opts?.label?.frameId ?? null,
  389. groupIds: opts?.label?.groupIds === undefined
  390. ? opts?.groupIds
  391. : opts?.label?.groupIds ,
  392. });
  393. mutateElement(
  394. rectangle,
  395. {
  396. boundElements: [{ type: "text", id: text.id }],
  397. },
  398. false,
  399. );
  400. return [rectangle, text];
  401. };
  402. static createLabeledArrow = (opts?: {
  403. frameId?: ExcalidrawElement["id"];
  404. label?: {
  405. text?: string;
  406. frameId?: ExcalidrawElement["id"] | null;
  407. };
  408. }) => {
  409. const arrow = API.createElement({
  410. type: "arrow",
  411. frameId: opts?.frameId || null,
  412. });
  413. const text = API.createElement({
  414. type: "text",
  415. width: 50,
  416. height: 20,
  417. containerId: arrow.id,
  418. frameId:
  419. opts?.label?.frameId === undefined
  420. ? opts?.frameId ?? null
  421. : opts?.label?.frameId ?? null,
  422. });
  423. mutateElement(
  424. arrow,
  425. {
  426. boundElements: [{ type: "text", id: text.id }],
  427. },
  428. false,
  429. );
  430. return [arrow, text];
  431. };
  432. static readFile = async <T extends "utf8" | null>(
  433. filepath: string,
  434. encoding?: T,
  435. ): Promise<T extends "utf8" ? string : Buffer> => {
  436. filepath = path.isAbsolute(filepath)
  437. ? filepath
  438. : path.resolve(path.join(__dirname, "../", filepath));
  439. return readFile(filepath, { encoding }) as any;
  440. };
  441. static loadFile = async (filepath: string) => {
  442. const { base, ext } = path.parse(filepath);
  443. return new File([await API.readFile(filepath, null)], base, {
  444. type: getMimeType(ext),
  445. });
  446. };
  447. static drop = async (blob: Blob) => {
  448. const fileDropEvent = createEvent.drop(GlobalTestState.interactiveCanvas);
  449. const text = await new Promise<string>((resolve, reject) => {
  450. try {
  451. const reader = new FileReader();
  452. reader.onload = () => {
  453. resolve(reader.result as string);
  454. };
  455. reader.readAsText(blob);
  456. } catch (error: any) {
  457. reject(error);
  458. }
  459. });
  460. const files = [blob] as File[] & { item: (index: number) => File };
  461. files.item = (index: number) => files[index];
  462. Object.defineProperty(fileDropEvent, "dataTransfer", {
  463. value: {
  464. files,
  465. getData: (type: string) => {
  466. if (type === blob.type) {
  467. return text;
  468. }
  469. return "";
  470. },
  471. },
  472. });
  473. await fireEvent(GlobalTestState.interactiveCanvas, fileDropEvent);
  474. };
  475. static executeAction = (action: Action) => {
  476. act(() => {
  477. h.app.actionManager.executeAction(action);
  478. });
  479. };
  480. }