1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
import os
import sqlite3
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import closing
from typing import Callable, List, NamedTuple, Optional

import openai
from tenacity import (
    retry,
    stop_after_attempt,
    stop_after_delay,
    wait_random_exponential,
)
from tree_sitter import Language, Parser
# Take the API key from the environment instead of hard-coding it in source.
# The previous unconditional `openai.api_key = ""` overwrote any key already
# configured (NOTE(review): the openai 0.x client reads OPENAI_API_KEY on
# import -- confirm for the pinned version). "" is kept as the fallback.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
# Immutable byte span (start inclusive, end exclusive by slicing convention)
# of a comment -- or of several merged comments -- inside the source buffer.
COMMENT = NamedTuple("COMMENT", [("start_byte", int), ("end_byte", int)])
# Compile the vendored tree-sitter-proto grammar into the shared library
# that is loaded further down as the 'proto' language.
Language.build_library(
    output_path='build/my-languages.so',
    repo_paths=['vendor/tree-sitter-proto'],
)
# Create the on-disk translation cache if it does not exist yet.
# sqlite3's connection context manager only commits/rolls back; closing() is
# what actually releases the connection handle.
with closing(sqlite3.connect('translations.db')) as conn:
    with conn:  # commit on success, roll back on error
        conn.execute('''
            CREATE TABLE IF NOT EXISTS translations
            (language text, text text, translation text)
        ''')
        # Every cache lookup filters on (language, text); without an index
        # each lookup scans the whole table.
        conn.execute(
            'CREATE INDEX IF NOT EXISTS idx_translations_language_text '
            'ON translations (language, text)'
        )
def get_translation_from_cache(language: str, text: str) -> Optional[str]:
    """Return a previously stored translation, or ``None`` on a cache miss.

    Args:
        language: Language label the translation was stored under.
        text: The exact original text that was translated.

    Returns:
        The cached translation string, or ``None`` when no row matches.
    """
    # closing() guarantees the connection is released; `with conn` alone
    # (as before) only manages the transaction and leaks the handle.
    with closing(sqlite3.connect('translations.db')) as conn:
        row = conn.execute(
            "SELECT translation FROM translations WHERE language=? AND text=? LIMIT 1",
            (language, text),
        ).fetchone()
    return row[0] if row else None
def store_translation_in_cache(language: str, text: str, translation: str):
    """Persist *translation* for (*language*, *text*) unless one is already cached.

    The first stored translation for a key wins; later stores for the same
    key are ignored. That matches what lookups observed before, when extra
    duplicate rows were inserted but only one was ever read back.

    Args:
        language: Language label to store the translation under.
        text: The original text.
        translation: The translated text.
    """
    with closing(sqlite3.connect('translations.db')) as conn:
        with conn:  # commit on success, roll back on error
            # Existence check and insert in one statement, so concurrent
            # workers translating identical text do not pile up duplicate rows.
            conn.execute(
                "INSERT INTO translations (language, text, translation) "
                "SELECT ?, ?, ? WHERE NOT EXISTS ("
                "SELECT 1 FROM translations WHERE language=? AND text=?)",
                (language, text, translation, language, text),
            )
@retry(wait=wait_random_exponential(min=1, max=60), stop=(stop_after_delay(60) | stop_after_attempt(5)))
def translate(language: str, text: str) -> str:
    """Translate one comment block to Chinese via the OpenAI chat API, with caching.

    The SQLite cache is consulted first; on a miss the model is queried with
    temperature 0 and the result is stored before being returned. The whole
    call is retried with random exponential back-off (1-60 s) and gives up
    after 5 attempts or 60 s, whichever comes first.

    Args:
        language: Name of the source language, interpolated into the prompt.
        text: Comment text to translate.

    Returns:
        The model's translation (or the cached one).
    """
    cached = get_translation_from_cache(language, text)
    # Compare with None: a cached empty translation is still a hit and must not
    # cause another paid request plus another cache insert.
    if cached is not None:
        return cached
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        # NOTE(review): 200 minutes per request dwarfs the 60 s retry budget;
        # tenacity's stop conditions are evaluated between attempts, so a hung
        # request is not cut short by them. Consider a much smaller timeout.
        request_timeout=200 * 60,
        temperature=0,
        messages=[
            {"role": "system", "content": f"将以下{language}代码注释翻译到中文, 要求1: 翻译后单行文本最好不要超过60个字符 要求2: 保持注释格式 要求3: 只做翻译,不要解释内容"},
            {"role": "user", "content": f'{text}'},
        ],
    )
    translation = response["choices"][0]["message"]["content"]
    store_translation_in_cache(language, text, translation)
    return translation
def bulk_translate(
    language: str,
    texts: List[str],
    *,
    max_workers: int = 8,
    translator: Optional[Callable[[str, str], str]] = None,
) -> List[str]:
    """Translate many texts concurrently and return results in input order.

    Identical texts are translated only once; every occurrence in the output
    receives that single translation.

    Args:
        language: Source-language label forwarded to the translator.
        texts: Texts to translate.
        max_workers: Thread-pool size for concurrent translation calls.
        translator: Callable ``(language, text) -> translation``; defaults to
            this module's ``translate``.

    Returns:
        A list whose i-th element is the translation of ``texts[i]``.
    """
    if translator is None:
        translator = translate
    print("start translate")
    # dict.fromkeys de-duplicates while keeping first-seen order, so repeated
    # comments do not trigger redundant API calls or racing cache inserts.
    unique_texts = list(dict.fromkeys(texts))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in submission order, so they can be
        # zipped straight back onto their inputs.
        results = list(executor.map(lambda t: translator(language, t), unique_texts))
    by_text = dict(zip(unique_texts, results))
    return [by_text[t] for t in texts]
def is_continue(text: bytes, start: int, end: int) -> bool:
    """Report whether ``text[start:end]`` is empty or consists only of whitespace.

    This is the default merge predicate for ``merge_intervals``: two
    comments count as one block when nothing but whitespace separates them.
    """
    gap = text[start:end]
    return not gap.strip()
def merge_intervals(intervals: List[COMMENT], text: bytes, func=is_continue) -> List[COMMENT]:
    """Coalesce consecutive spans whose separating gap satisfies *func*.

    Args:
        intervals: Objects exposing ``start_byte`` / ``end_byte``; presumably
            already in ascending document order (not checked here).
        text: The buffer the byte offsets index into.
        func: Predicate ``(text, gap_start, gap_end) -> bool``; when it holds
            for the gap between the previous result and the current span, the
            two are fused.

    Returns:
        A new list. Fused spans are fresh ``COMMENT`` instances; spans that
        were never fused are the original input objects.
    """
    merged: List[COMMENT] = []
    for current in intervals:
        if merged and func(text, merged[-1].end_byte, current.start_byte):
            previous = merged[-1]
            merged[-1] = COMMENT(start_byte=previous.start_byte, end_byte=current.end_byte)
            continue
        merged.append(current)
    return merged
def trans_comments(tree, source_code: bytes, language: str = "proto") -> bytes:
    """Return *source_code* with every comment replaced by its translation.

    Comment nodes are gathered from the syntax tree in document order.
    Neighbouring comments separated only by whitespace are merged into one
    span so multi-line comment blocks are translated as a unit. Each span is
    then replaced by the translated text.

    Args:
        tree: Parsed syntax tree; its ``root_node`` is walked for nodes whose
            ``type`` is ``'comment'``.
        source_code: The bytes the tree was parsed from (decoded as UTF-8).
        language: Language label passed through to the translator prompt.

    Returns:
        The rewritten source, UTF-8 encoded.
    """
    comments: List[COMMENT] = []
    # Iterative pre-order walk (children pushed in reverse) visits nodes in the
    # same order as recursion would, without risking RecursionError.
    stack = [tree.root_node]
    while stack:
        node = stack.pop()
        if node.type == 'comment':
            comments.append(node)
        else:
            stack.extend(reversed(node.children))

    comments = merge_intervals(comments, source_code)
    translations = bulk_translate(
        language,
        [source_code[c.start_byte:c.end_byte].decode("utf-8") for c in comments],
    )

    # Build the output in one pass instead of re-slicing the whole buffer once
    # per comment; merged spans are ascending and non-overlapping.
    pieces: List[bytes] = []
    cursor = 0
    for comment, translated in zip(comments, translations):
        pieces.append(source_code[cursor:comment.start_byte])
        pieces.append(translated.encode("utf-8"))
        cursor = comment.end_byte
    pieces.append(source_code[cursor:])
    return b"".join(pieces)
parser = Parser()
# Load the 'proto' grammar from the shared library built above and bind it
# to the parser that reads the input file.
PROTO_LANGUAGE = Language('build/my-languages.so', 'proto')
parser.set_language(PROTO_LANGUAGE)
# Read the schema, translate its comments, and write the translated copy.
# A `with` block closes the input file deterministically; the previous bare
# open(...).read() left the handle for the garbage collector.
with open('scip.proto', 'rb') as input_file:
    source_code = input_file.read()
tree = parser.parse(source_code)
source_code = trans_comments(tree, source_code)

with open('scip_trans.proto', 'wb') as output_file:
    output_file.write(source_code)
|