Typescript Compiler API で定義元の AST ノードを取得する

備忘録。TypeScript の Compiler API を使ってソースコードのパースから特定の型・関数の定義元の AST ノードへジャンプする方法について。

TypeScript のバージョンは 4.1.3。

AST ノードの基礎

ノードの型の判別

全ての AST ノードはインタフェース Node を実装している。ノードの種類毎に型が異なるが、いずれのノード型も kind プロパティに列挙型 SyntaxKind を持っているので、これを調べることになる。

例えば関数呼び出し式型 CallExpression について考える。あるノードが CallExpression かどうかを確かめるためには、kindSyntaxKind.CallExpression であるかをチェックする。

const node: ts.Node = ...
if (node.kind === ts.SyntaxKind.CallExpression) {
  // 関数呼び出し式に対する処理
  const callExpr = node as ts.CallExpression
  callExpr.argument.forEach(...)
}

ありがたいことに主要なノード型の Type Guard を isFooBar という名前で提供している。例えば関数呼び出し式の Type Guard は isCallExpression() である。

function isCallExpression(node: Node): node is CallExpression

これを使えば上記コードからキャストを消すことができる。

const node: ts.Node = ...
if (ts.isCallExpression(node)) {
  // Type Guard を通しているので、直接 CallExpression として扱える
  node.arguments.forEach(...)
}

ノードの探査

ノードの探査は forEachChild() を使う。引数には探査対象のノードとコールバックを渡す。コールバックが真となる値を返すと、その時点で探査が終了してその値を返すことに注意。1回の呼び出しでは直下の子ノードのみを探査するので、必要に応じて再帰的に呼び出すことになる。

例えば AST 内で関数呼び出し式を1つ見つけるコードは以下の通り。

function callback(node: ts.Node): ts.CallExpression | undefined {
  if (isCallExpression(node)) {
    return node
  }
  return ts.forEachChild(node, callback)
}

const callExpression = ts.forEachChild(sourceFile, callback)

解析例

以下の2ファイルがあるとする。

src/api.ts

import { SearchParams, SearchResponse } from "./model"

export function search(params: SearchParams): SearchResponse {
  throw new Error("not implemented")
}

src/model.ts

export type SearchParams = {
  name: string
}

export interface SearchResponse {
  name: string
}

export interface SearchResponse {
  age: number
}

api.ts 内の search の宣言から、search の実装及び引数・戻り値の型定義の情報を出力することを目指す。

main.ts 全体の AST を取得する。

まずは createProgram() にエントリポイントを渡すことで、プログラム全体の AST を取得する。トップレベルノードの Programmain.ts だけでなく、import によって参照される全てのファイルも含んでいる。特定のファイルを表すノードは Program.getSourceFile() などで取得できる。

import * as ts from "typescript"
import * as path from "path"

function assertIsNotUndefined<T>(x: T | undefined): asserts x is T {
  if (x === undefined) {
    throw new TypeError("unexpected undefined")
  }
}

const filename = "src/main.ts"
const compilerOptions: ts.CompilerOptions = {}
const program: ts.Program = ts.createProgram([filename], compilerOptions)
const sourceFile: ts.SourceFile | undefined = program.getSourceFile(filename)

assertIsNotUndefined(sourceFile)

search を探す

sourceFile から以下の条件で search を宣言しているノード)を探す。

  • ノードの種類が FunctionDeclaration
  • 関数名が search
function findSearchDeclaration(
  node: ts.Node
): ts.FunctionDeclaration | undefined {
  if (
    ts.isFunctionDeclaration(node) &&
    node.name !== undefined &&
    ts.idText(node.name) === "search"
  ) {
    return node
  }

  return ts.forEachChild(node, findSearchDeclaration)
}

const searchDecl = ts.forEachChild(sourceFile, findSearchDeclaration)
assertIsNotUndefined(searchDecl)

search の実装を出力する

AST を出力するためには、まず createPrinter()Printer オブジェクトを作る。 Printer.printNode() でノードのコード文字列を取得することができる。

const printer: ts.Printer = ts.createPrinter()
console.log(printer.printNode(ts.EmitHint.Unspecified, searchDecl, sourceFile))

戻り値の型の定義を取得する

関数定義の戻り値の型宣言のノードは FunctionDeclaration.type から取得できる。

型ノードからその宣言を取得するには以下の手順を踏む:

  1. 対象プログラムの型検査器にあたる TypeChecker を取得(Program.getTypeChecker()
  2. 型ノードが表す型情報そのものを表すオブジェクト Type を取得(TypeChecker.getTypeAtLocation()
  3. 型オブジェクトからシンボルオブジェクト Symbol を取得(Type.getSymbol()
  4. シンボルオブジェクトからそのシンボルを宣言した AST 一覧を取得(Symbol.getDeclarations()
const checker: ts.TypeChecker = program.getTypeChecker()
assertIsNotUndefined(searchDecl.type)
const returnType: ts.Type | undefined = checker.getTypeAtLocation(
  searchDecl.type
)
assertIsNotUndefined(returnType)

const symbol: ts.Symbol | undefined = returnType.getSymbol()
assertIsNotUndefined(symbol)

const typeDecls: ts.Declaration[] | undefined = symbol.getDeclarations()
assertIsNotUndefined(typeDecls)

シンボルを宣言した AST が複数あるのは関数のオーバーロードinterface が複数回にわたって宣言される可能性があるため。

戻り値型の宣言の出力

Declaration[] の出力は forEach などで行ってもいいが、createNodeArray() を使って NodeArray 型に変換することで、Printer.printList() に任せることもできる。以下は各宣言を改行で join して出力する例。

console.log("declarations of return type of search")
console.log(
  printer.printList(
    ts.ListFormat.MultiLine | ts.ListFormat.NoTrailingNewLine,
    ts.factory.createNodeArray(typeDecls),
    sourceFile
  )
)

まとめ

コードの全体像は以下の通り

import * as ts from "typescript"

function assertIsNotUndefined<T>(x: T | undefined): asserts x is T {
  if (x === undefined) {
    throw new TypeError("unexpected undefined")
  }
}

const filename = "src/api.ts"
const compilerOptions: ts.CompilerOptions = {}
const program: ts.Program = ts.createProgram([filename], compilerOptions)
const sourceFile: ts.SourceFile | undefined = program.getSourceFile(filename)
assertIsNotUndefined(sourceFile)

function findSearchDeclaration(
  node: ts.Node
): ts.FunctionDeclaration | undefined {
  if (
    ts.isFunctionDeclaration(node) &&
    node.name !== undefined &&
    ts.idText(node.name) === "search"
  ) {
    return node
  }

  return ts.forEachChild(node, findSearchDeclaration)
}

const searchDecl = ts.forEachChild(sourceFile, findSearchDeclaration)
assertIsNotUndefined(searchDecl)

const printer: ts.Printer = ts.createPrinter()
console.log("declaration of search")
console.log(printer.printNode(ts.EmitHint.Unspecified, searchDecl, sourceFile))

const checker: ts.TypeChecker = program.getTypeChecker()
assertIsNotUndefined(searchDecl.type)
const returnType: ts.Type | undefined = checker.getTypeAtLocation(
  searchDecl.type
)
assertIsNotUndefined(returnType)

const symbol: ts.Symbol | undefined = returnType.getSymbol()
assertIsNotUndefined(symbol)

const typeDecls: ts.Declaration[] | undefined = symbol.getDeclarations()
assertIsNotUndefined(typeDecls)
console.log("declarations of return type of search")
console.log(
  printer.printList(
    ts.ListFormat.MultiLine | ts.ListFormat.NoTrailingNewLine,
    ts.factory.createNodeArray(typeDecls),
    sourceFile
  )
)