
Commit

Merge pull request #15 from ConnectAI-E/featuer_notion
Featuer notion
lloydzhou authored Nov 14, 2023
2 parents fccc221 + ebf0a11 commit 394cf14
Showing 3 changed files with 234 additions and 2 deletions.
58 changes: 58 additions & 0 deletions server/celery_app.py
@@ -13,6 +13,7 @@
LarkWikiLoader,
LarkDocLoader,
YuqueDocLoader,
NotionDocLoader,
)


@@ -64,6 +65,26 @@ def embed_documents(fileUrl, fileType, fileName, collection_id, openai=False, un
version=0,
)
document_ids.append(document_id)

elif fileType in ['notion']:
# Import a Notion document
collection = get_collection_by_id(None, collection_id)
user = get_user(collection.user_id)
extra = user.extra.to_dict()
notion = extra.get('notion', {})
loader = NotionDocLoader(fileUrl, **notion)
doc = loader.load()
document_id = embedding_single_document(
doc, fileUrl, fileType,
doc.metadata.get('title'),
collection_id,
openai=openai,
uniqid=doc.metadata.get('uniqid'),
version=0
)
document_ids.append(document_id)


elif fileType in ['pdf', 'word', 'excel', 'markdown', 'ppt', 'txt']:
loader_class, loader_args = LOADER_MAPPING[fileType]
# These are all plain files: download them first, then load
@@ -212,3 +233,40 @@ def sync_yuque(openai=False):

logging.info("updated document_ids %r", document_ids)

@celery.task()
def sync_notion(openai=False):
document_ids = []
response = Search(index="document").filter(
"term", type="notion"
).filter(
"term", status=0,
).extra(
from_=0, size=10000
).sort({"modified": {"order": "desc"}}).execute()
total = response.hits.total.value
logging.info("debug sync_notion %r", total)
for document in response:
try:
collection = get_collection_by_id(None, document.collection_id)
user = get_user(collection.user_id)
extra = user.extra.to_dict()
notion = extra.get('notion', {})
loader = NotionDocLoader(document.path, **notion)
# There is no version number, so load the document once and compare timestamps to decide whether to re-embed and re-store it
doc = loader.load()
if doc.metadata.get('modified') > document.modified:
document_id = embedding_single_document(
doc, document.path, document.type,
doc.metadata.get('title'),
document.collection_id,
openai=openai,
uniqid=doc.metadata.get('uniqid'),
version=0, # currently only Feishu (Lark) documents need the version bumped
)
document_ids.append(document_id)
# remove the old document
purge_document_by_id(document.meta.id)
except Exception as e:
logging.error('error to sync_notion %r %r', document.path, e)

logging.info("updated document_ids %r", document_ids)
17 changes: 16 additions & 1 deletion server/routes.py
@@ -43,7 +43,7 @@
)
from celery_app import embed_documents, get_status_by_id, embed_feishuwiki
from sse import ServerSentEvents
from tasks import LarkDocLoader, YuqueDocLoader, LarkWikiLoader
from tasks import LarkDocLoader, YuqueDocLoader, NotionDocLoader, LarkWikiLoader


class InternalError(Exception): pass
@@ -527,6 +527,21 @@ def api_embed_documents(collection_id):
'code': -1,
'msg': str(e)
})
elif fileType == 'notion':
# For Notion documents, try loading synchronously so a failure is reported to the caller right away
try:
collection = get_collection_by_id(None, collection_id)
user = get_user(user_id)
extra = user.extra.to_dict()
notion = extra.get('notion', {})
loader = NotionDocLoader(fileUrl, **notion)
doc = loader.load()
except Exception as e:
app.logger.error(e)
return jsonify({
'code': -1,
'msg': str(e)
})
# isopenai=False
task = embed_documents.delay(fileUrl, fileType, fileName, collection_id, False, uniqid=uniqid)
return jsonify({
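The route above validates a Notion URL synchronously by constructing NotionDocLoader with whatever is stored under user.extra['notion']. Since the loader (defined in server/tasks.py below) only reads config.get('token') and sends that value verbatim as the Authorization header, the stored entry presumably has to carry the full header value. A hedged sketch of that usage, with a placeholder token and URL:

# Hedged sketch, not part of this commit: the per-user config NotionDocLoader consumes.
# The loader sends config.get('token') verbatim as the Authorization header, so the
# stored value presumably already includes the "Bearer " scheme the Notion API expects.
from tasks import NotionDocLoader

notion = {"token": "Bearer secret_xxxxxxxxxxxx"}  # placeholder token

loader = NotionDocLoader(
    "https://www.notion.so/b1-8beaa48d081e44e69000cd789726a151",
    **notion,
)
doc = loader.load()  # raises if the integration cannot read the page
print(doc.metadata["title"], doc.metadata["uniqid"])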
161 changes: 160 additions & 1 deletion server/tasks.py
@@ -1,4 +1,6 @@
import os
from urllib.parse import urlparse

import httpx
from functools import cached_property
from tempfile import NamedTemporaryFile
@@ -68,8 +70,13 @@ def __call__(self, *args, **kwargs):
},
"sync_yuque": {
"task": "celery_app.sync_yuque",
"schedule": timedelta(seconds=3610), # 定时1hours执行一次,避免任务一起执行,占资源
"schedule": timedelta(seconds=3700), # 定时1hours执行一次,避免任务一起执行,占资源
"args": (False) # 函数传参的值
},
"sync_notion": {
"task": "celery_app.sync_notion",
"schedule": timedelta(seconds=3800), # 定时1hours执行一次,避免任务一起执行,占资源
"args": (False) # 函数传参的值
}
}

@@ -281,6 +288,7 @@ def file_info(self):
class YuqueDocLoader(object):

def __init__(self, fileUrl, **kwargs):
self.fileUrl = fileUrl
# https://www.yuque.com/yuque/developer/doc
self.fileUrl = fileUrl
temp = fileUrl.split('?')[0].split('/')
@@ -328,4 +336,155 @@ def load(self):
)
)

class NotionDocLoader(object):

def __init__(self, fileUrl, **kwargs):
# https://www.notion.so/b1-8beaa48d081e44e69000cd789726a151?pvs=4
# https://www.notion.so/b1-8beaa48d081e44e69000cd789726a151
self.fileUrl = fileUrl
self.page_id, self.title = self.extract_ids(fileUrl)
self.config = kwargs

# The Notion page title is embedded in the URL, and the raw ID needs hyphens re-inserted, so a dedicated helper handles both
def extract_ids(self, link):
# Parse the URL
parsed_url = urlparse(link)
# Extract the ID from the URL path
path = parsed_url.path
parts = path.split('/')
ids = parts[-1]
# The last 32 characters are the raw page ID; re-insert hyphens to get the UUID form
id = ids[-32:-24] + '-' + ids[-24:-20] + '-' + ids[-20:-16] + '-' + ids[-16:-12] + '-' + ids[-12:]
title = ids[:-33]
return id, title

def retrieve_block_children(self):
"""
The URL here is the blocks/children API endpoint built from the processed page ID.
"""
url = f"https://api.notion.com/v1/blocks/{self.page_id}/children"

# The pinned Notion-Version header may need updating when the API changes
headers = {
"Authorization": self.config.get('token'),
"Notion-Version": '2022-06-28'
}


blocks = []
cursor = None
while True:
params = {}
if cursor:
params["start_cursor"] = cursor
res = httpx.get(url, headers=headers, params=params).json()

blocks.extend(res.get("results", []))
has_more = res.get("has_more", False)
if not has_more:
break
cursor = res.get("next_cursor")
# print(json.dumps(blocks))
return blocks


# Join the plain_text of every rich_text item in a block (paragraph.rich_text[*].plain_text)
def get_plain_text_from_rich_text(self, rich_text):
return "".join([t['plain_text'] for t in rich_text])

def get_text_from_block(self, block):
text = ""
# Exception handling is needed because the fetched content does not always contain block[block['type']]['rich_text']
block = block[0]  # only the first block of the fetched list is inspected
try:
if block[block['type']]['rich_text']:
text = self.get_plain_text_from_rich_text(block[block['type']]['rich_text'])
except (KeyError, TypeError):
block_type = block['type']
if block_type == "unsupported":
text = "[Unsupported block type]"

elif block_type == "bookmark":
text = block['bookmark']['url']

elif block_type == "child_database":
text = block['child_database']['title']

elif block_type == "child_page":
text = block['child_page']['title']

# elif block_type in ["embed", "video", "file", "image", "pdf"]:
# text = get_media_source_text(block)

elif block_type == "equation":
text = block['equation']['expression']

elif block_type == "link_preview":
text = block['link_preview']['url']

elif block_type == "synced_block":
if 'synced_from' in block['synced_block']:
synced_with_block = block['synced_block']['synced_from']
text = f"This block is synced with a block with the following ID: {synced_with_block[synced_with_block['type']]}"
else:
text = "Source sync block that another blocked is synced with."

elif block_type == "table":
text = f"Table width: {block['table']['table_width']}"

elif block_type == "table_of_contents":
text = f"ToC color: {block['table_of_contents']['color']}"

elif block_type in ["breadcrumb", "column_list", "divider"]:
text = "No text available"

else:
text = "[Needs case added]"

string_data = f"{block['type']}: {text}"
# Split on the first colon only, since the text itself may contain colons (e.g. bookmark URLs)
key, value = string_data.split(':', 1)
# Build a dict keyed by block type
my_dict = {key.strip(): value.strip()}

# Only return paragraph body text
if 'paragraph' in my_dict and my_dict['paragraph'] != '':
return my_dict['paragraph']
else:
return ''

# return f"{block['type']}: {text}"

def load(self):
url = f"https://api.notion.com/v1/blocks/{self.page_id}/children"

blocks, cursor = [], None
headers = {
"Authorization": self.config.get('token'),
"Notion-Version": '2022-06-28'
}
params = {}
if cursor:
params["start_cursor"] = cursor
res = httpx.get(url, headers=headers,params=params).json()
blocks.extend(res.get("results", []))

text = self.get_text_from_block(blocks)

if not res.get('results'):
# app.logger.error("error get content %r", res)
raise Exception('The 企联 AI Notion assistant does not have permission to access this document')
# raise Exception(f'error get content for document')
return Document(
page_content=text,
metadata=dict(
fileUrl=self.fileUrl,
id=self.page_id,
title=self.title,
# unique ID used to distinguish documents
uniqid=f"{self.title}-{self.page_id}",
modified=datetime.fromisoformat(res['results'][0]['last_edited_time'].split('.')[0]), # last_edited_time of the first returned block
)
)
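The slicing in extract_ids is easiest to follow with a worked example (not part of the commit). For the sample URL in the constructor comment, the trailing path segment is b1-8beaa48d081e44e69000cd789726a151: the last 32 characters are the raw page ID, which is re-hyphenated into UUID form, and everything before that ID and its separator is treated as the title.

# Worked example (not part of the commit) of what NotionDocLoader.extract_ids computes.
from urllib.parse import urlparse

link = "https://www.notion.so/b1-8beaa48d081e44e69000cd789726a151?pvs=4"
segment = urlparse(link).path.split('/')[-1]  # 'b1-8beaa48d081e44e69000cd789726a151'

raw_id = segment[-32:]                        # '8beaa48d081e44e69000cd789726a151'
page_id = '-'.join([raw_id[0:8], raw_id[8:12], raw_id[12:16], raw_id[16:20], raw_id[20:]])
title = segment[:-33]                         # drop the 32-char ID and its '-' separator

print(page_id)  # 8beaa48d-081e-44e6-9000-cd789726a151
print(title)    # b1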

