344 lines
12 KiB
OCaml
344 lines
12 KiB
OCaml
open Parser
|
|
|
|
(** [default_db_path ()] returns the default location of the completion
    database, [$XDG_CACHE_HOME/inshellah/completions.db], falling back to
    [$HOME/.cache/inshellah/completions.db].

    Following the XDG Base Directory specification, an [XDG_CACHE_HOME]
    that is unset, empty, or not an absolute path is ignored and the
    [$HOME/.cache] default is used instead.

    @raise Not_found if the fallback is needed and [HOME] is unset. *)
let default_db_path () =
  let xdg = try Sys.getenv "XDG_CACHE_HOME" with Not_found -> "" in
  let cache =
    if xdg <> "" && not (Filename.is_relative xdg) then xdg
    else Filename.concat (Sys.getenv "HOME") ".cache"
  in
  Filename.concat cache "inshellah/completions.db"
|
|
|
|
(** [ensure_parent path] creates every missing directory leading to
    [Filename.dirname path], parents first (like [mkdir -p]), each with
    mode [0o755] (subject to the umask).  Existing directories are left
    alone.

    @raise Unix.Unix_error if a directory cannot be created for any reason
    other than it already existing. *)
let ensure_parent path =
  let rec mkdir_p dir =
    if not (Sys.file_exists dir) then begin
      mkdir_p (Filename.dirname dir);
      (* Another process may create [dir] between the existence check
         above and this call; that outcome is what we wanted anyway. *)
      try Unix.mkdir dir 0o755
      with Unix.Unix_error (Unix.EEXIST, _, _) -> ()
    end
  in
  mkdir_p (Filename.dirname path)
|
|
|
|
(** [init db_path] opens the completion database at [db_path] and returns
    the handle.  Before opening, it creates any missing parent
    directories.  It then sets [journal_mode=WAL] and [synchronous=NORMAL]
    and creates the [completions] table if it does not already exist.

    @raise Failure if one of the setup statements fails.  The freshly
    opened handle is closed before the exception escapes, so it does not
    leak. *)
let init db_path =
  ensure_parent db_path;
  let db = Sqlite3.db_open db_path in
  let exec sql =
    match Sqlite3.exec db sql with
    | Sqlite3.Rc.OK -> ()
    | rc ->
      failwith (Printf.sprintf "sqlite: %s: %s" (Sqlite3.Rc.to_string rc) sql)
  in
  match
    exec "PRAGMA journal_mode=WAL";
    exec "PRAGMA synchronous=NORMAL";
    exec {|CREATE TABLE IF NOT EXISTS completions (
  command TEXT PRIMARY KEY,
  data TEXT NOT NULL,
  source TEXT,
  updated_at INTEGER NOT NULL
)|}
  with
  | () -> db
  | exception e ->
    (* Release the handle; the caller never receives it on failure. *)
    ignore (Sqlite3.db_close db);
    raise e
|
|
|
|
(** [close db] closes the database handle via [Sqlite3.db_close],
    discarding whatever that call returns. *)
let close db =
  Sqlite3.db_close db |> ignore
|
|
|
|
(* --- JSON serialization of help_result --- *)
|
|
|
|
(** [escape_json s] makes [s] safe to place between the quotes of a JSON
    string literal.
    - Double quotes and backslashes get a backslash prefix.
    - Newline, tab and carriage return use their short escapes.
    - Every other byte below 0x20 becomes a six-character [\u00XX] escape.
    - All remaining bytes, including non-ASCII UTF-8, are copied
      unchanged. *)
let escape_json s =
  let out = Buffer.create (String.length s + 4) in
  for i = 0 to String.length s - 1 do
    match s.[i] with
    | '"' -> Buffer.add_string out "\\\""
    | '\\' -> Buffer.add_string out "\\\\"
    | '\n' -> Buffer.add_string out "\\n"
    | '\t' -> Buffer.add_string out "\\t"
    | '\r' -> Buffer.add_string out "\\r"
    | ch ->
      let code = Char.code ch in
      if code < 0x20 then Buffer.add_string out (Printf.sprintf "\\u%04x" code)
      else Buffer.add_char out ch
  done;
  Buffer.contents out
|
|
|
|
(** [json_string s] renders [s] as a complete JSON string literal:
    escaped by [escape_json] and wrapped in double quotes. *)
let json_string s = "\"" ^ escape_json s ^ "\""

(** The JSON null literal, used for absent optional values. *)
let json_null = "null"
|
|
|
|
(** [json_switch_of sw] serializes an option switch as a JSON object
    tagged by its ["type"] field:
    - ["short"] carries ["char"];
    - ["long"] carries ["name"];
    - ["both"] carries both fields. *)
let json_switch_of sw =
  let char_json c = json_string (String.make 1 c) in
  match sw with
  | Short c ->
    Printf.sprintf "{\"type\":\"short\",\"char\":%s}" (char_json c)
  | Long name ->
    Printf.sprintf "{\"type\":\"long\",\"name\":%s}" (json_string name)
  | Both (c, name) ->
    Printf.sprintf "{\"type\":\"both\",\"char\":%s,\"name\":%s}"
      (char_json c) (json_string name)
|
|
|
|
(** [json_param_of param] serializes an option's parameter.
    - [None] becomes JSON null.
    - [Some p] becomes an object with a ["kind"] of ["mandatory"] or
      ["optional"] plus the parameter's ["name"]. *)
let json_param_of param =
  match param with
  | None -> json_null
  | Some p ->
    (match p with
     | Mandatory name ->
       Printf.sprintf "{\"kind\":\"mandatory\",\"name\":%s}" (json_string name)
     | Optional name ->
       Printf.sprintf "{\"kind\":\"optional\",\"name\":%s}" (json_string name))
|
|
|
|
(** [json_entry_of e] serializes one option entry as an object with
    ["switch"], ["param"] and ["desc"] fields. *)
let json_entry_of e =
  let { switch; param; desc; _ } = e in
  let switch_json = json_switch_of switch in
  let param_json = json_param_of param in
  Printf.sprintf "{\"switch\":%s,\"param\":%s,\"desc\":%s}"
    switch_json param_json (json_string desc)
|
|
|
|
(** [json_subcommand_of sc] serializes a subcommand as an object with
    ["name"] and ["desc"] string fields. *)
let json_subcommand_of sc =
  let name_json = json_string sc.name
  and desc_json = json_string sc.desc in
  Printf.sprintf "{\"name\":%s,\"desc\":%s}" name_json desc_json
|
|
|
|
(** [json_positional_of p] serializes a positional argument.  Its
    [pos_name] field is written under the JSON key ["name"]; the two flags
    are written as JSON booleans. *)
let json_positional_of p =
  let { pos_name; optional; variadic; _ } = p in
  Printf.sprintf "{\"name\":%s,\"optional\":%b,\"variadic\":%b}"
    (json_string pos_name) optional variadic
|
|
|
|
(** [json_list f items] renders [items] as a JSON array, using [f] to
    render each element.  [f] is expected to return JSON text. *)
let json_list f items =
  let body = items |> List.map f |> String.concat "," in
  "[" ^ body ^ "]"
|
|
|
|
(** [json_of_help_result r] serializes a parsed help result to a compact
    JSON object holding three arrays: ["entries"], ["subcommands"] and
    ["positionals"]. *)
let json_of_help_result r =
  let { entries; subcommands; positionals; _ } = r in
  Printf.sprintf "{\"entries\":%s,\"subcommands\":%s,\"positionals\":%s}"
    (json_list json_entry_of entries)
    (json_list json_subcommand_of subcommands)
    (json_list json_positional_of positionals)
|
|
|
|
(* --- JSON deserialization --- *)
|
|
|
|
(* Minimal JSON parser — just enough for our own output *)
|
|
|
|
(** In-memory JSON value produced by [parse_json].  There is no number
    constructor: the serializers in this file never emit numbers. *)
type json =
  | Jnull                            (** the [null] literal *)
  | Jbool of bool                    (** [true] or [false] *)
  | Jstring of string                (** a string with escapes decoded *)
  | Jarray of json list              (** array elements, in order *)
  | Jobject of (string * json) list  (** object members as key/value pairs *)
|
|
|
|
(** [json_get key j] is the value bound to [key] in the object [j].  If
    the key occurs more than once, the first binding wins.  It returns
    [Jnull] when [key] is absent or [j] is not an object. *)
let json_get key j =
  match j with
  | Jobject members ->
    (match List.assoc key members with
     | value -> value
     | exception Not_found -> Jnull)
  | Jnull | Jbool _ | Jstring _ | Jarray _ -> Jnull
|
|
|
|
(* Lenient accessors.  Each returns the payload of the expected
   constructor, or a neutral default (empty string, false, empty list)
   for any other value.  That includes the [Jnull] returned by [json_get]
   for a missing key. *)
let json_to_string j =
  match j with
  | Jstring str -> str
  | _ -> ""

let json_to_bool j =
  match j with
  | Jbool flag -> flag
  | _ -> false

let json_to_list j =
  match j with
  | Jarray items -> items
  | _ -> []
|
|
|
|
(** Raised by [parse_json] when it cannot parse its input.  The payload
    is a human-readable message that includes the byte offset reached. *)
exception Json_error of string
|
|
|
|
(** [parse_json s] parses the JSON value at the start of [s].

    This is a deliberately small parser, just enough to read back what
    [json_of_help_result] writes.  It handles objects, arrays, strings and
    the literals [true], [false] and [null].  Numbers are not supported,
    and anything after the first complete value is ignored.

    String escapes follow RFC 8259, including [\b], [\f], [\/] and
    [\uXXXX].  A UTF-16 surrogate pair is combined into one code point; a
    lone surrogate, which has no UTF-8 encoding, becomes U+FFFD.  Unknown
    escapes are kept leniently as the escaped character.

    @raise Json_error on malformed or truncated input, including an
    unterminated string (which previously looped forever). *)
let parse_json s =
  let len = String.length s in
  let pos = ref 0 in
  let at_end () = !pos >= len in
  let peek () = if !pos < len then s.[!pos] else '\x00' in
  let advance () = incr pos in
  let error msg = raise (Json_error (Printf.sprintf "%s at %d" msg !pos)) in
  let skip_ws () =
    while
      !pos < len
      && (match s.[!pos] with ' ' | '\t' | '\n' | '\r' -> true | _ -> false)
    do
      advance ()
    done
  in
  let expect c =
    skip_ws ();
    if at_end () || s.[!pos] <> c then
      raise (Json_error (Printf.sprintf "expected '%c' at %d" c !pos));
    advance ()
  in
  (* Consume the exact literal [word] (e.g. "null") and return [value];
     the old parser skipped a fixed number of bytes without checking. *)
  let keyword word value =
    let n = String.length word in
    if !pos + n <= len && String.sub s !pos n = word then begin
      pos := !pos + n;
      value
    end else error ("invalid literal, expected " ^ word)
  in
  (* Append Unicode code point [cp] (at most 0x10FFFF) to [buf] as UTF-8. *)
  let add_utf8 buf cp =
    let byte b = Buffer.add_char buf (Char.chr b) in
    if cp < 0x80 then byte cp
    else if cp < 0x800 then begin
      byte (0xc0 lor (cp lsr 6));
      byte (0x80 lor (cp land 0x3f))
    end else if cp < 0x10000 then begin
      byte (0xe0 lor (cp lsr 12));
      byte (0x80 lor ((cp lsr 6) land 0x3f));
      byte (0x80 lor (cp land 0x3f))
    end else begin
      byte (0xf0 lor (cp lsr 18));
      byte (0x80 lor ((cp lsr 12) land 0x3f));
      byte (0x80 lor ((cp lsr 6) land 0x3f));
      byte (0x80 lor (cp land 0x3f))
    end
  in
  (* Read exactly four hex digits at the cursor and return their value. *)
  let read_hex4 () =
    if !pos + 4 > len then error "truncated \\u escape";
    let value = ref 0 in
    for i = 0 to 3 do
      let digit =
        match s.[!pos + i] with
        | '0' .. '9' as c -> Char.code c - Char.code '0'
        | 'a' .. 'f' as c -> Char.code c - Char.code 'a' + 10
        | 'A' .. 'F' as c -> Char.code c - Char.code 'A' + 10
        | _ -> error "invalid hex digit in \\u escape"
      in
      value := (!value lsl 4) lor digit
    done;
    pos := !pos + 4;
    !value
  in
  (* Decode the hex part of a \u escape whose backslash and u are already
     consumed.  A high surrogate directly followed by a \u low surrogate
     is combined into one supplementary code point. *)
  let read_code_point () =
    let first = read_hex4 () in
    if first >= 0xD800 && first <= 0xDBFF then begin
      let mark = !pos in
      if !pos + 2 <= len && s.[!pos] = '\\' && s.[!pos + 1] = 'u' then begin
        pos := !pos + 2;
        let second = read_hex4 () in
        if second >= 0xDC00 && second <= 0xDFFF then
          0x10000 + ((first - 0xD800) lsl 10) + (second - 0xDC00)
        else begin
          (* Not a pair: rewind so the second escape is decoded alone. *)
          pos := mark;
          0xFFFD
        end
      end else 0xFFFD
    end else if first >= 0xDC00 && first <= 0xDFFF then 0xFFFD
    else first
  in
  (* Decode one escape sequence; the backslash is already consumed. *)
  let decode_escape buf =
    if at_end () then error "unterminated escape";
    let c = s.[!pos] in
    advance ();
    match c with
    | '"' -> Buffer.add_char buf '"'
    | '\\' -> Buffer.add_char buf '\\'
    | '/' -> Buffer.add_char buf '/'
    | 'b' -> Buffer.add_char buf '\b'
    | 'f' -> Buffer.add_char buf '\012'
    | 'n' -> Buffer.add_char buf '\n'
    | 'r' -> Buffer.add_char buf '\r'
    | 't' -> Buffer.add_char buf '\t'
    | 'u' -> add_utf8 buf (read_code_point ())
    | other -> Buffer.add_char buf other
  in
  let rec parse_value () =
    skip_ws ();
    if at_end () then error "unexpected end of input";
    match s.[!pos] with
    | '"' -> Jstring (parse_string ())
    | '{' -> parse_object ()
    | '[' -> parse_array ()
    | 'n' -> keyword "null" Jnull
    | 't' -> keyword "true" (Jbool true)
    | 'f' -> keyword "false" (Jbool false)
    | c -> raise (Json_error (Printf.sprintf "unexpected '%c' at %d" c !pos))
  and parse_string () =
    expect '"';
    let buf = Buffer.create 32 in
    let closed = ref false in
    while not !closed do
      (* Bounds check: the old loop compared against a NUL sentinel past
         the end and never terminated on an unclosed string. *)
      if at_end () then error "unterminated string";
      let c = s.[!pos] in
      advance ();
      (match c with
       | '"' -> closed := true
       | '\\' -> decode_escape buf
       | c -> Buffer.add_char buf c)
    done;
    Buffer.contents buf
  and parse_object () =
    expect '{';
    skip_ws ();
    if peek () = '}' then (advance (); Jobject [])
    else begin
      let pairs = ref [] in
      let more = ref true in
      while !more do
        let key = parse_string () in
        expect ':';
        let value = parse_value () in
        pairs := (key, value) :: !pairs;
        skip_ws ();
        if peek () = ',' then advance () else more := false
      done;
      expect '}';
      Jobject (List.rev !pairs)
    end
  and parse_array () =
    expect '[';
    skip_ws ();
    if peek () = ']' then (advance (); Jarray [])
    else begin
      let items = ref [] in
      let more = ref true in
      while !more do
        let v = parse_value () in
        items := v :: !items;
        skip_ws ();
        if peek () = ',' then advance () else more := false
      done;
      expect ']';
      Jarray (List.rev !items)
    end
  in
  parse_value ()
|
|
|
|
(** [switch_of_json j] rebuilds a switch from an object in the shape
    written by [json_switch_of].
    - A missing or empty ["char"] becomes ['?'].
    - An unrecognised ["type"] yields [Long "?"]. *)
let switch_of_json j =
  let field key = json_to_string (json_get key j) in
  let first_char str = if str = "" then '?' else str.[0] in
  match field "type" with
  | "short" -> Short (first_char (field "char"))
  | "long" -> Long (field "name")
  | "both" -> Both (first_char (field "char"), field "name")
  | _ -> Long "?"
|
|
|
|
(** [param_of_json j] rebuilds an option parameter.
    - JSON null gives [None].
    - An object whose ["kind"] is ["mandatory"] or ["optional"] gives the
      matching constructor, carrying its ["name"].
    - Anything else also gives [None]. *)
let param_of_json j =
  match j, json_to_string (json_get "kind" j) with
  | Jnull, _ -> None
  | _, "mandatory" -> Some (Mandatory (json_to_string (json_get "name" j)))
  | _, "optional" -> Some (Optional (json_to_string (json_get "name" j)))
  | _ -> None
|
|
|
|
(** [entry_of_json j] rebuilds one option entry from its ["switch"],
    ["param"] and ["desc"] fields. *)
let entry_of_json j =
  let field key = json_get key j in
  let switch = switch_of_json (field "switch") in
  let param = param_of_json (field "param") in
  let desc = json_to_string (field "desc") in
  { switch; param; desc }
|
|
|
|
(** [subcommand_of_json j] rebuilds a subcommand from its ["name"] and
    ["desc"] string fields; missing fields become empty strings. *)
let subcommand_of_json j =
  let text key = json_to_string (json_get key j) in
  { name = text "name"; desc = text "desc" }
|
|
|
|
(** [positional_of_json j] rebuilds a positional argument.  The JSON key
    ["name"] maps to [pos_name]; missing flags default to [false]. *)
let positional_of_json j =
  let text key = json_to_string (json_get key j)
  and flag key = json_to_bool (json_get key j) in
  { pos_name = text "name";
    optional = flag "optional";
    variadic = flag "variadic" }
|
|
|
|
(** [help_result_of_json j] rebuilds a help result from the object written
    by [json_of_help_result].  A missing or non-array field yields an
    empty list. *)
let help_result_of_json j =
  let decode_list key decode = List.map decode (json_to_list (json_get key j)) in
  { entries = decode_list "entries" entry_of_json;
    subcommands = decode_list "subcommands" subcommand_of_json;
    positionals = decode_list "positionals" positional_of_json }
|
|
|
|
(* --- Database operations --- *)
|
|
|
|
(** [upsert db ?source command result] serializes [result] to JSON and
    stores it as the completion data for [command], replacing any existing
    row for that command.  [source] defaults to ["help"].  [updated_at] is
    set to the current Unix time, truncated to whole seconds.

    The prepared statement is finalized on every path, including failures.

    @raise Failure if a parameter cannot be bound or the statement does
    not run to completion. *)
let upsert db ?(source="help") command result =
  let json = json_of_help_result result in
  let now = int_of_float (Unix.gettimeofday ()) in
  let stmt =
    Sqlite3.prepare db
      "INSERT INTO completions (command, data, source, updated_at) \
       VALUES (?, ?, ?, ?) \
       ON CONFLICT(command) DO UPDATE SET data=excluded.data, \
       source=excluded.source, updated_at=excluded.updated_at"
  in
  (* A failed bind would otherwise leave the parameter NULL and surface
     later as a confusing constraint error, or none at all. *)
  let bind idx value =
    match Sqlite3.bind stmt idx value with
    | Sqlite3.Rc.OK -> ()
    | rc ->
      failwith
        (Printf.sprintf "upsert %s: bind %d: %s" command idx
           (Sqlite3.Rc.to_string rc))
  in
  let execute () =
    bind 1 (Sqlite3.Data.TEXT command);
    bind 2 (Sqlite3.Data.TEXT json);
    bind 3 (Sqlite3.Data.TEXT source);
    bind 4 (Sqlite3.Data.INT (Int64.of_int now));
    match Sqlite3.step stmt with
    | Sqlite3.Rc.DONE -> ()
    | rc ->
      failwith (Printf.sprintf "upsert %s: %s" command (Sqlite3.Rc.to_string rc))
  in
  (* Finalize on success and on failure so a failed write does not leak
     the prepared statement. *)
  let finalize () = ignore (Sqlite3.finalize stmt) in
  match execute () with
  | () -> finalize ()
  | exception e -> finalize (); raise e
|
|
|
|
(** [upsert_raw db ?source command data] stores [data] verbatim as the
    completion data for [command], replacing any existing row for that
    command.  [source] defaults to ["native"].  [updated_at] is set to the
    current Unix time, truncated to whole seconds.

    The prepared statement is finalized on every path, including failures.

    @raise Failure if a parameter cannot be bound or the statement does
    not run to completion. *)
let upsert_raw db ?(source="native") command data =
  let now = int_of_float (Unix.gettimeofday ()) in
  let stmt =
    Sqlite3.prepare db
      "INSERT INTO completions (command, data, source, updated_at) \
       VALUES (?, ?, ?, ?) \
       ON CONFLICT(command) DO UPDATE SET data=excluded.data, \
       source=excluded.source, updated_at=excluded.updated_at"
  in
  (* A failed bind would otherwise leave the parameter NULL and surface
     later as a confusing constraint error, or none at all. *)
  let bind idx value =
    match Sqlite3.bind stmt idx value with
    | Sqlite3.Rc.OK -> ()
    | rc ->
      failwith
        (Printf.sprintf "upsert_raw %s: bind %d: %s" command idx
           (Sqlite3.Rc.to_string rc))
  in
  let execute () =
    bind 1 (Sqlite3.Data.TEXT command);
    bind 2 (Sqlite3.Data.TEXT data);
    bind 3 (Sqlite3.Data.TEXT source);
    bind 4 (Sqlite3.Data.INT (Int64.of_int now));
    match Sqlite3.step stmt with
    | Sqlite3.Rc.DONE -> ()
    | rc ->
      failwith
        (Printf.sprintf "upsert_raw %s: %s" command (Sqlite3.Rc.to_string rc))
  in
  (* Finalize on success and on failure so a failed write does not leak
     the prepared statement. *)
  let finalize () = ignore (Sqlite3.finalize stmt) in
  match execute () with
  | () -> finalize ()
  | exception e -> finalize (); raise e
|
|
|
|
(** [lookup db command] returns [Some (data, source)] for the stored row
    of [command].  Any step result other than a row, including an error,
    yields [None]. *)
let lookup db command =
  let stmt =
    Sqlite3.prepare db "SELECT data, source FROM completions WHERE command = ?"
  in
  ignore (Sqlite3.bind stmt 1 (Sqlite3.Data.TEXT command));
  let row =
    if Sqlite3.step stmt = Sqlite3.Rc.ROW then begin
      let data = Sqlite3.column_text stmt 0 in
      let source = Sqlite3.column_text stmt 1 in
      Some (data, source)
    end else None
  in
  ignore (Sqlite3.finalize stmt);
  row
|
|
|
|
(** [lookup_result db command] returns the stored help result for
    [command].  It returns [None] when there is no row, or when the stored
    data cannot be parsed and decoded as a help-result JSON document.

    Only parse/decode errors are mapped to [None].  Unrelated exceptions
    (e.g. [Out_of_memory], or [Sys.Break] on Ctrl-C) are no longer
    swallowed. *)
let lookup_result db command =
  match lookup db command with
  | None -> None
  | Some (data, _source) ->
    (match help_result_of_json (parse_json data) with
     | result -> Some result
     (* The parser may signal bad input through any of these. *)
     | exception (Json_error _ | Failure _ | Invalid_argument _) -> None)
|
|
|
|
(** [has_command db command] is [true] exactly when the first step of the
    query yields a row for [command].  Any other result code, including
    an error, counts as absent. *)
let has_command db command =
  let stmt = Sqlite3.prepare db "SELECT 1 FROM completions WHERE command = ?" in
  ignore (Sqlite3.bind stmt 1 (Sqlite3.Data.TEXT command));
  let found =
    match Sqlite3.step stmt with
    | Sqlite3.Rc.ROW -> true
    | _ -> false
  in
  ignore (Sqlite3.finalize stmt);
  found
|
|
|
|
(** [all_commands db] lists every stored command name in the order given
    by [ORDER BY command].  Collection stops at the first step result that
    is not a row. *)
let all_commands db =
  let stmt = Sqlite3.prepare db "SELECT command FROM completions ORDER BY command" in
  let rec collect acc =
    match Sqlite3.step stmt with
    | Sqlite3.Rc.ROW -> collect (Sqlite3.column_text stmt 0 :: acc)
    | _ -> List.rev acc
  in
  let commands = collect [] in
  ignore (Sqlite3.finalize stmt);
  commands
|
|
|
|
(** [delete db command] removes the row for [command], if one exists.
    NOTE(review): the bind and step results are ignored, so a failed
    delete is silent.  Confirm that callers rely on this best-effort
    behaviour. *)
let delete db command =
  let stmt = Sqlite3.prepare db "DELETE FROM completions WHERE command = ?" in
  Sqlite3.bind stmt 1 (Sqlite3.Data.TEXT command) |> ignore;
  Sqlite3.step stmt |> ignore;
  Sqlite3.finalize stmt |> ignore
|
|
|
|
(** [begin_transaction db] issues [BEGIN IMMEDIATE] and ignores its result
    code.  If it fails, later statements simply run outside a transaction.
    NOTE(review): this looks like deliberate best-effort behaviour paired
    with [commit]; confirm before making either strict. *)
let begin_transaction db =
  ignore (Sqlite3.exec db "BEGIN IMMEDIATE" : Sqlite3.Rc.t)
|
|
|
|
(** [commit db] issues [COMMIT] and ignores its result code, mirroring
    [begin_transaction].  NOTE(review): a failed commit is therefore
    silent; confirm that this best-effort behaviour is intended. *)
let commit db =
  ignore (Sqlite3.exec db "COMMIT" : Sqlite3.Rc.t)
|
|
|
|
(** [stats db] returns a pair: the total number of stored commands, and
    the number of distinct non-NULL [source] values.  If the query yields
    no row, it returns [(0, 0)]. *)
let stats db =
  let stmt =
    Sqlite3.prepare db "SELECT COUNT(*), COUNT(DISTINCT source) FROM completions"
  in
  let counts =
    if Sqlite3.step stmt = Sqlite3.Rc.ROW then begin
      let rows = Sqlite3.column_int stmt 0 in
      let distinct_sources = Sqlite3.column_int stmt 1 in
      (rows, distinct_sources)
    end else (0, 0)
  in
  ignore (Sqlite3.finalize stmt);
  counts
|